{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b439e84-8c6b-4126-8586-d7a1a7c7614e",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Experiments validating model self-distillation (NOTE(review): comment said CIFAR10, but this notebook is configured for CIFAR100 — see DATA_NAME below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3534cbe3-ca73-485f-b363-f64694bd1369",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Self-distillation procedure:\n",
     "# 1. Split the model's training set, given a keep ratio\n",
     "# 2. Split the training set multiple times into multiple shadow datasets\n",
     "# 3. Train multiple shadow models\n",
     "# 4. Merge the outputs of the shadow models\n",
     "# 5. Train the target model\n",
     "# 6. Run the membership-inference attack\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b61f101e-1c1e-41df-b80b-65d1e3d6eab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms import ToTensor\n",
    "import torchvision.transforms as tt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import metrics\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "37ddaa77-35ce-49b9-acfd-a7799aadd9a5",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Import project-local python modules\n",
     "import sys\n",
     "sys.path.append(\"..\") # Adds higher directory to python modules path.\n",
     "# NOTE(review): wildcard imports hide where names (CustomDataset, shadow_attack, ...) come from;\n",
     "# prefer explicit imports once the frame API is stable\n",
     "from frame.DataProcess import *\n",
     "from frame.TrainUtil import *\n",
     "from frame.LIRAAttack import *\n",
     "from frame.AttackUtil import *\n",
     "from frame.ShadowAttack import *\n",
     "from frame.ThresholdAttack import *\n",
     "from frame.LabelAttack import *\n",
     "from frame.ModelUtil import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4636a18e-244e-4a21-ba21-591c0295ce7a",
   "metadata": {},
   "outputs": [],
   "source": [
     "# --- Experiment configuration ---\n",
     "LEARNING_RATE = 1e-1\n",
     "BATCH_SIZE = 128\n",
     "MODEL = 'ResNet18'\n",
     "EPOCHS = 100\n",
     "DATA_NAME = 'CIFAR100' \n",
     "weight_dir = os.path.join('..', 'weights_for_exp', DATA_NAME)\n",
     "num_shadowsets = 100\n",
     "seed = 0\n",
     "prop_keep = 0.5\n",
     "\n",
     "# Training-time transform: augmentation + normalization\n",
     "model_transform = transforms.Compose([\n",
     "    transforms.ToPILImage(),\n",
     "    transforms.RandomCrop(32, padding=4),  # pad 4 px of zeros on each side, then randomly crop back to 32x32\n",
     "    transforms.RandomHorizontalFlip(),\n",
     "    transforms.ToTensor(),\n",
     "    transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
     "    ])\n",
     "# Evaluation-time transform: normalization only, no augmentation\n",
     "model_transform_test = transforms.Compose([\n",
     "    transforms.ToTensor(),\n",
     "    transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
     "])\n",
     "\n",
     "attack_transform = transforms.Compose([])\n",
     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
     "\n",
     "# Shadow-model attack parameters\n",
     "sha_models = [2,3,4] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]\n",
     "tar_model = 0\n",
     "attack_class = False # whether to run a separate attack for each class\n",
     "attack_lr = 5e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e8818f5c-0cf8-4635-a3d6-292f91f14d75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
     "# Load the full training dataset plus per-shadow-set membership masks (train_keep)\n",
     "X_data, Y_data, train_keep = load_CIFAR100_keep(num_shadowsets, prop_keep, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "63ef244d-acb8-47a2-8af6-0ee2329408fb",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Wrap the full dataset with the evaluation transform; keep original order (shuffle=False)\n",
     "all_data = CustomDataset(X_data, Y_data, model_transform_test)\n",
     "all_dataloader = DataLoader(all_data, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7ff79cce-9f89-46a9-b04a-8f716bb66540",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Lower-case aliases used by the helper calls below\n",
     "batch_size = BATCH_SIZE\n",
     "model = MODEL\n",
     "epochs = EPOCHS\n",
     "data_name = DATA_NAME \n",
     "# Filename stem of the saved shadow-model weights\n",
     "weight_part = \"{}_{}_epoch{}_shadownum100_model\".format(data_name, model, epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "42098920-5c02-4007-9bb8-bd65a73344e4",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Load precomputed per-sample shadow-model outputs (loss / score / confidence)\n",
     "loss_data_all = np.load('../outputs_save/CIFAR100_resnet_loss.npy')\n",
     "score_all = np.load('../outputs_save/CIFAR100_resnet_score.npy')\n",
     "conf_data_all = np.load('../outputs_save/CIFAR100_resnet_conf.npy')\n",
     "# Per-sample privacy-risk score; argsort then flip => indices ordered highest risk first\n",
     "pri_risk_all = get_risk_score(loss_data_all, train_keep)\n",
     "pri_risk_rank = np.argsort(pri_risk_all)\n",
     "pri_risk_rank = np.flip(pri_risk_rank)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80eac30-31c4-4e90-a6e8-220a7c2a1989",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "40cf7e11-de44-4cc6-9c47-a8225c8e7d95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "(24869, 32, 32, 3) (24869,) (25131, 32, 32, 3) (25131,)\n",
      " Error: \n",
      " Accuracy: 99.9%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 64.3%  \n",
      "\n",
      "(50000, 100) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "(25028, 32, 32, 3) (25028,) (24972, 32, 32, 3) (24972,)\n",
      " Error: \n",
      " Accuracy: 99.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 63.3%  \n",
      "\n",
      "(50000, 100) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "(25126, 32, 32, 3) (25126,) (24874, 32, 32, 3) (24874,)\n",
      " Error: \n",
      " Accuracy: 99.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 63.3%  \n",
      "\n",
      "(50000, 100) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "(25020, 32, 32, 3) (25020,) (24980, 32, 32, 3) (24980,)\n",
      " Error: \n",
      " Accuracy: 99.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 63.6%  \n",
      "\n",
      "test data: (50000, 100) (50000,) (50000,)\n",
      "(150000, 100) (150000,)\n",
      "Attack_NN(\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=4, out_features=128, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=128, out_features=64, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=64, out_features=1, bias=True)\n",
      "  )\n",
      ")\n",
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 0.697224  [  128/150000]\n",
      "loss: 0.554047  [12928/150000]\n",
      "loss: 0.512395  [25728/150000]\n",
      "loss: 0.474868  [38528/150000]\n",
      "loss: 0.559360  [51328/150000]\n",
      "loss: 0.455311  [64128/150000]\n",
      "loss: 0.484935  [76928/150000]\n",
      "loss: 0.489729  [89728/150000]\n",
      "loss: 0.482427  [102528/150000]\n",
      "loss: 0.427475  [115328/150000]\n",
      "loss: 0.440608  [128128/150000]\n",
      "loss: 0.583961  [140928/150000]\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 0.544936  [  128/150000]\n",
      "loss: 0.421262  [12928/150000]\n",
      "loss: 0.446645  [25728/150000]\n",
      "loss: 0.510597  [38528/150000]\n",
      "loss: 0.462909  [51328/150000]\n",
      "loss: 0.493341  [64128/150000]\n",
      "loss: 0.481638  [76928/150000]\n",
      "loss: 0.474367  [89728/150000]\n",
      "loss: 0.509177  [102528/150000]\n",
      "loss: 0.529549  [115328/150000]\n",
      "loss: 0.507267  [128128/150000]\n",
      "loss: 0.490833  [140928/150000]\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 0.479290  [  128/150000]\n",
      "loss: 0.428154  [12928/150000]\n",
      "loss: 0.478899  [25728/150000]\n",
      "loss: 0.488318  [38528/150000]\n",
      "loss: 0.602147  [51328/150000]\n",
      "loss: 0.533132  [64128/150000]\n",
      "loss: 0.510031  [76928/150000]\n",
      "loss: 0.469753  [89728/150000]\n",
      "loss: 0.466672  [102528/150000]\n",
      "loss: 0.405882  [115328/150000]\n",
      "loss: 0.480704  [128128/150000]\n",
      "loss: 0.501340  [140928/150000]\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 0.478872  [  128/150000]\n",
      "loss: 0.455217  [12928/150000]\n",
      "loss: 0.468596  [25728/150000]\n",
      "loss: 0.398414  [38528/150000]\n",
      "loss: 0.475869  [51328/150000]\n",
      "loss: 0.473554  [64128/150000]\n",
      "loss: 0.479793  [76928/150000]\n",
      "loss: 0.514152  [89728/150000]\n",
      "loss: 0.447902  [102528/150000]\n",
      "loss: 0.464730  [115328/150000]\n",
      "loss: 0.489643  [128128/150000]\n",
      "loss: 0.494449  [140928/150000]\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 0.429394  [  128/150000]\n",
      "loss: 0.493888  [12928/150000]\n",
      "loss: 0.549939  [25728/150000]\n",
      "loss: 0.454794  [38528/150000]\n",
      "loss: 0.461338  [51328/150000]\n",
      "loss: 0.367450  [64128/150000]\n",
      "loss: 0.507288  [76928/150000]\n",
      "loss: 0.474333  [89728/150000]\n",
      "loss: 0.526260  [102528/150000]\n",
      "loss: 0.491090  [115328/150000]\n",
      "loss: 0.444966  [128128/150000]\n",
      "loss: 0.472245  [140928/150000]\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 0.447177  [  128/150000]\n",
      "loss: 0.467791  [12928/150000]\n",
      "loss: 0.439126  [25728/150000]\n",
      "loss: 0.495911  [38528/150000]\n",
      "loss: 0.404101  [51328/150000]\n",
      "loss: 0.483073  [64128/150000]\n",
      "loss: 0.411588  [76928/150000]\n",
      "loss: 0.513404  [89728/150000]\n",
      "loss: 0.490935  [102528/150000]\n",
      "loss: 0.448009  [115328/150000]\n",
      "loss: 0.488673  [128128/150000]\n",
      "loss: 0.497216  [140928/150000]\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 0.506236  [  128/150000]\n",
      "loss: 0.540286  [12928/150000]\n",
      "loss: 0.447288  [25728/150000]\n",
      "loss: 0.464637  [38528/150000]\n",
      "loss: 0.424826  [51328/150000]\n",
      "loss: 0.452686  [64128/150000]\n",
      "loss: 0.495223  [76928/150000]\n",
      "loss: 0.494302  [89728/150000]\n",
      "loss: 0.530311  [102528/150000]\n",
      "loss: 0.494813  [115328/150000]\n",
      "loss: 0.362646  [128128/150000]\n",
      "loss: 0.439862  [140928/150000]\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 0.488929  [  128/150000]\n",
      "loss: 0.495009  [12928/150000]\n",
      "loss: 0.532610  [25728/150000]\n",
      "loss: 0.486136  [38528/150000]\n",
      "loss: 0.513144  [51328/150000]\n",
      "loss: 0.409610  [64128/150000]\n",
      "loss: 0.463498  [76928/150000]\n",
      "loss: 0.565585  [89728/150000]\n",
      "loss: 0.465556  [102528/150000]\n",
      "loss: 0.514595  [115328/150000]\n",
      "loss: 0.481498  [128128/150000]\n",
      "loss: 0.567436  [140928/150000]\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 0.450865  [  128/150000]\n",
      "loss: 0.508641  [12928/150000]\n",
      "loss: 0.446823  [25728/150000]\n",
      "loss: 0.409848  [38528/150000]\n",
      "loss: 0.497180  [51328/150000]\n",
      "loss: 0.400605  [64128/150000]\n",
      "loss: 0.478868  [76928/150000]\n",
      "loss: 0.457959  [89728/150000]\n",
      "loss: 0.425989  [102528/150000]\n",
      "loss: 0.524350  [115328/150000]\n",
      "loss: 0.472402  [128128/150000]\n",
      "loss: 0.493299  [140928/150000]\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 0.485928  [  128/150000]\n",
      "loss: 0.550855  [12928/150000]\n",
      "loss: 0.444851  [25728/150000]\n",
      "loss: 0.453331  [38528/150000]\n",
      "loss: 0.455764  [51328/150000]\n",
      "loss: 0.559042  [64128/150000]\n",
      "loss: 0.524302  [76928/150000]\n",
      "loss: 0.484624  [89728/150000]\n",
      "loss: 0.432667  [102528/150000]\n",
      "loss: 0.523299  [115328/150000]\n",
      "loss: 0.502926  [128128/150000]\n",
      "loss: 0.455255  [140928/150000]\n",
      "Epoch 11\n",
      "-------------------------------\n",
      "loss: 0.470769  [  128/150000]\n",
      "loss: 0.513126  [12928/150000]\n",
      "loss: 0.539786  [25728/150000]\n",
      "loss: 0.462196  [38528/150000]\n",
      "loss: 0.458585  [51328/150000]\n",
      "loss: 0.480692  [64128/150000]\n",
      "loss: 0.453237  [76928/150000]\n",
      "loss: 0.437583  [89728/150000]\n",
      "loss: 0.501007  [102528/150000]\n",
      "loss: 0.473877  [115328/150000]\n",
      "loss: 0.464201  [128128/150000]\n",
      "loss: 0.495804  [140928/150000]\n",
      "Epoch 12\n",
      "-------------------------------\n",
      "loss: 0.485485  [  128/150000]\n",
      "loss: 0.479002  [12928/150000]\n",
      "loss: 0.446548  [25728/150000]\n",
      "loss: 0.550057  [38528/150000]\n",
      "loss: 0.504658  [51328/150000]\n",
      "loss: 0.419189  [64128/150000]\n",
      "loss: 0.419141  [76928/150000]\n",
      "loss: 0.457962  [89728/150000]\n",
      "loss: 0.511034  [102528/150000]\n",
      "loss: 0.470289  [115328/150000]\n",
      "loss: 0.511620  [128128/150000]\n",
      "loss: 0.419199  [140928/150000]\n",
      "Epoch 13\n",
      "-------------------------------\n",
      "loss: 0.461920  [  128/150000]\n",
      "loss: 0.513022  [12928/150000]\n",
      "loss: 0.568752  [25728/150000]\n",
      "loss: 0.544514  [38528/150000]\n",
      "loss: 0.479188  [51328/150000]\n",
      "loss: 0.418225  [64128/150000]\n",
      "loss: 0.414616  [76928/150000]\n",
      "loss: 0.441753  [89728/150000]\n",
      "loss: 0.468708  [102528/150000]\n",
      "loss: 0.504753  [115328/150000]\n",
      "loss: 0.471287  [128128/150000]\n",
      "loss: 0.475969  [140928/150000]\n",
      "Epoch 14\n",
      "-------------------------------\n",
      "loss: 0.552783  [  128/150000]\n",
      "loss: 0.469312  [12928/150000]\n",
      "loss: 0.414544  [25728/150000]\n",
      "loss: 0.478464  [38528/150000]\n",
      "loss: 0.388563  [51328/150000]\n",
      "loss: 0.427832  [64128/150000]\n",
      "loss: 0.502035  [76928/150000]\n",
      "loss: 0.513928  [89728/150000]\n",
      "loss: 0.490727  [102528/150000]\n",
      "loss: 0.404948  [115328/150000]\n",
      "loss: 0.481431  [128128/150000]\n",
      "loss: 0.482584  [140928/150000]\n",
      "Epoch 15\n",
      "-------------------------------\n",
      "loss: 0.466074  [  128/150000]\n",
      "loss: 0.400611  [12928/150000]\n",
      "loss: 0.428508  [25728/150000]\n",
      "loss: 0.456591  [38528/150000]\n",
      "loss: 0.427192  [51328/150000]\n",
      "loss: 0.546118  [64128/150000]\n",
      "loss: 0.505553  [76928/150000]\n",
      "loss: 0.464037  [89728/150000]\n",
      "loss: 0.391572  [102528/150000]\n",
      "loss: 0.517928  [115328/150000]\n",
      "loss: 0.532692  [128128/150000]\n",
      "loss: 0.481347  [140928/150000]\n",
      "Epoch 16\n",
      "-------------------------------\n",
      "loss: 0.465851  [  128/150000]\n",
      "loss: 0.463738  [12928/150000]\n",
      "loss: 0.493814  [25728/150000]\n",
      "loss: 0.544868  [38528/150000]\n",
      "loss: 0.456342  [51328/150000]\n",
      "loss: 0.484545  [64128/150000]\n",
      "loss: 0.438380  [76928/150000]\n",
      "loss: 0.519688  [89728/150000]\n",
      "loss: 0.408633  [102528/150000]\n",
      "loss: 0.497902  [115328/150000]\n",
      "loss: 0.539646  [128128/150000]\n",
      "loss: 0.480615  [140928/150000]\n",
      "Epoch 17\n",
      "-------------------------------\n",
      "loss: 0.478514  [  128/150000]\n",
      "loss: 0.468891  [12928/150000]\n",
      "loss: 0.476602  [25728/150000]\n",
      "loss: 0.563772  [38528/150000]\n",
      "loss: 0.492627  [51328/150000]\n",
      "loss: 0.458670  [64128/150000]\n",
      "loss: 0.554158  [76928/150000]\n",
      "loss: 0.390834  [89728/150000]\n",
      "loss: 0.470684  [102528/150000]\n",
      "loss: 0.441032  [115328/150000]\n",
      "loss: 0.544835  [128128/150000]\n",
      "loss: 0.463647  [140928/150000]\n",
      "Epoch 18\n",
      "-------------------------------\n",
      "loss: 0.508415  [  128/150000]\n",
      "loss: 0.425311  [12928/150000]\n",
      "loss: 0.491953  [25728/150000]\n",
      "loss: 0.417318  [38528/150000]\n",
      "loss: 0.550511  [51328/150000]\n",
      "loss: 0.441897  [64128/150000]\n",
      "loss: 0.560107  [76928/150000]\n",
      "loss: 0.555201  [89728/150000]\n",
      "loss: 0.459037  [102528/150000]\n",
      "loss: 0.498597  [115328/150000]\n",
      "loss: 0.466898  [128128/150000]\n",
      "loss: 0.491927  [140928/150000]\n",
      "Epoch 19\n",
      "-------------------------------\n",
      "loss: 0.468594  [  128/150000]\n",
      "loss: 0.483828  [12928/150000]\n",
      "loss: 0.529592  [25728/150000]\n",
      "loss: 0.447286  [38528/150000]\n",
      "loss: 0.429376  [51328/150000]\n",
      "loss: 0.478336  [64128/150000]\n",
      "loss: 0.474628  [76928/150000]\n",
      "loss: 0.538299  [89728/150000]\n",
      "loss: 0.481607  [102528/150000]\n",
      "loss: 0.467459  [115328/150000]\n",
      "loss: 0.491432  [128128/150000]\n",
      "loss: 0.450827  [140928/150000]\n",
      "Epoch 20\n",
      "-------------------------------\n",
      "loss: 0.433995  [  128/150000]\n",
      "loss: 0.560161  [12928/150000]\n",
      "loss: 0.446981  [25728/150000]\n",
      "loss: 0.430724  [38528/150000]\n",
      "loss: 0.529422  [51328/150000]\n",
      "loss: 0.561174  [64128/150000]\n",
      "loss: 0.470982  [76928/150000]\n",
      "loss: 0.442228  [89728/150000]\n",
      "loss: 0.468507  [102528/150000]\n",
      "loss: 0.492143  [115328/150000]\n",
      "loss: 0.405608  [128128/150000]\n",
      "loss: 0.556653  [140928/150000]\n",
      "Epoch 21\n",
      "-------------------------------\n",
      "loss: 0.493365  [  128/150000]\n",
      "loss: 0.459675  [12928/150000]\n",
      "loss: 0.506240  [25728/150000]\n",
      "loss: 0.462550  [38528/150000]\n",
      "loss: 0.419504  [51328/150000]\n",
      "loss: 0.467064  [64128/150000]\n",
      "loss: 0.487499  [76928/150000]\n",
      "loss: 0.555912  [89728/150000]\n",
      "loss: 0.481177  [102528/150000]\n",
      "loss: 0.475811  [115328/150000]\n",
      "loss: 0.416698  [128128/150000]\n",
      "loss: 0.519363  [140928/150000]\n",
      "Epoch 22\n",
      "-------------------------------\n",
      "loss: 0.533681  [  128/150000]\n",
      "loss: 0.548289  [12928/150000]\n",
      "loss: 0.461269  [25728/150000]\n",
      "loss: 0.494138  [38528/150000]\n",
      "loss: 0.447719  [51328/150000]\n",
      "loss: 0.521976  [64128/150000]\n",
      "loss: 0.450560  [76928/150000]\n",
      "loss: 0.465617  [89728/150000]\n",
      "loss: 0.443076  [102528/150000]\n",
      "loss: 0.480434  [115328/150000]\n",
      "loss: 0.491328  [128128/150000]\n",
      "loss: 0.449807  [140928/150000]\n",
      "Epoch 23\n",
      "-------------------------------\n",
      "loss: 0.421084  [  128/150000]\n",
      "loss: 0.507462  [12928/150000]\n",
      "loss: 0.453519  [25728/150000]\n",
      "loss: 0.465675  [38528/150000]\n",
      "loss: 0.465741  [51328/150000]\n",
      "loss: 0.503898  [64128/150000]\n",
      "loss: 0.487370  [76928/150000]\n",
      "loss: 0.444438  [89728/150000]\n",
      "loss: 0.587351  [102528/150000]\n",
      "loss: 0.466783  [115328/150000]\n",
      "loss: 0.471341  [128128/150000]\n",
      "loss: 0.528918  [140928/150000]\n",
      "Epoch 24\n",
      "-------------------------------\n",
      "loss: 0.425281  [  128/150000]\n",
      "loss: 0.473858  [12928/150000]\n",
      "loss: 0.477565  [25728/150000]\n",
      "loss: 0.497431  [38528/150000]\n",
      "loss: 0.436030  [51328/150000]\n",
      "loss: 0.488753  [64128/150000]\n",
      "loss: 0.470055  [76928/150000]\n",
      "loss: 0.477952  [89728/150000]\n",
      "loss: 0.493826  [102528/150000]\n",
      "loss: 0.498997  [115328/150000]\n",
      "loss: 0.436829  [128128/150000]\n",
      "loss: 0.506279  [140928/150000]\n",
      "Epoch 25\n",
      "-------------------------------\n",
      "loss: 0.571877  [  128/150000]\n",
      "loss: 0.503094  [12928/150000]\n",
      "loss: 0.416119  [25728/150000]\n",
      "loss: 0.544354  [38528/150000]\n",
      "loss: 0.433010  [51328/150000]\n",
      "loss: 0.506473  [64128/150000]\n",
      "loss: 0.435116  [76928/150000]\n",
      "loss: 0.595160  [89728/150000]\n",
      "loss: 0.475158  [102528/150000]\n",
      "loss: 0.517538  [115328/150000]\n",
      "loss: 0.519553  [128128/150000]\n",
      "loss: 0.482932  [140928/150000]\n",
      "Epoch 26\n",
      "-------------------------------\n",
      "loss: 0.479654  [  128/150000]\n",
      "loss: 0.602124  [12928/150000]\n",
      "loss: 0.523316  [25728/150000]\n",
      "loss: 0.519923  [38528/150000]\n",
      "loss: 0.585293  [51328/150000]\n",
      "loss: 0.468458  [64128/150000]\n",
      "loss: 0.619206  [76928/150000]\n",
      "loss: 0.526383  [89728/150000]\n",
      "loss: 0.461524  [102528/150000]\n",
      "loss: 0.406550  [115328/150000]\n",
      "loss: 0.505335  [128128/150000]\n",
      "loss: 0.473558  [140928/150000]\n",
      "Epoch 27\n",
      "-------------------------------\n",
      "loss: 0.425132  [  128/150000]\n",
      "loss: 0.465348  [12928/150000]\n",
      "loss: 0.486645  [25728/150000]\n",
      "loss: 0.519760  [38528/150000]\n",
      "loss: 0.506847  [51328/150000]\n",
      "loss: 0.541967  [64128/150000]\n",
      "loss: 0.520506  [76928/150000]\n",
      "loss: 0.486822  [89728/150000]\n",
      "loss: 0.476903  [102528/150000]\n",
      "loss: 0.429918  [115328/150000]\n",
      "loss: 0.512524  [128128/150000]\n",
      "loss: 0.433281  [140928/150000]\n",
      "Epoch 28\n",
      "-------------------------------\n",
      "loss: 0.467590  [  128/150000]\n",
      "loss: 0.514641  [12928/150000]\n",
      "loss: 0.504446  [25728/150000]\n",
      "loss: 0.500821  [38528/150000]\n",
      "loss: 0.481756  [51328/150000]\n",
      "loss: 0.464246  [64128/150000]\n",
      "loss: 0.427159  [76928/150000]\n",
      "loss: 0.486407  [89728/150000]\n",
      "loss: 0.509579  [102528/150000]\n",
      "loss: 0.486120  [115328/150000]\n",
      "loss: 0.508524  [128128/150000]\n",
      "loss: 0.442165  [140928/150000]\n",
      "Epoch 29\n",
      "-------------------------------\n",
      "loss: 0.469828  [  128/150000]\n",
      "loss: 0.438455  [12928/150000]\n",
      "loss: 0.453277  [25728/150000]\n",
      "loss: 0.483211  [38528/150000]\n",
      "loss: 0.468512  [51328/150000]\n",
      "loss: 0.514062  [64128/150000]\n",
      "loss: 0.441799  [76928/150000]\n",
      "loss: 0.418952  [89728/150000]\n",
      "loss: 0.557525  [102528/150000]\n",
      "loss: 0.461728  [115328/150000]\n",
      "loss: 0.498426  [128128/150000]\n",
      "loss: 0.522774  [140928/150000]\n",
      "Epoch 30\n",
      "-------------------------------\n",
      "loss: 0.557969  [  128/150000]\n",
      "loss: 0.408953  [12928/150000]\n",
      "loss: 0.525713  [25728/150000]\n",
      "loss: 0.485438  [38528/150000]\n",
      "loss: 0.469902  [51328/150000]\n",
      "loss: 0.440189  [64128/150000]\n",
      "loss: 0.569537  [76928/150000]\n",
      "loss: 0.550817  [89728/150000]\n",
      "loss: 0.383593  [102528/150000]\n",
      "loss: 0.489277  [115328/150000]\n",
      "loss: 0.457965  [128128/150000]\n",
      "loss: 0.505174  [140928/150000]\n",
      "Done!\n",
      "Train data:\n",
      "AUC value is: 0.7762425696235233\n",
      "Accuracy is: 0.7616133333333334\n",
      "Test data:\n",
      "AUC value is: 0.7760501814721161\n",
      "Accuracy is: 0.7635\n"
     ]
    }
   ],
   "source": [
    "# 训练影子攻击模型\n",
    "attack_model = shadow_attack(sha_models=sha_models, tar_model=tar_model, model_num=num_shadowsets, weight_dir=weight_dir, data_name=DATA_NAME, model=MODEL, model_transform=model_transform_test, \n",
    "                  model_epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=attack_lr, attack_epochs=30, attack_transform=attack_transform, \n",
    "                  device=device, prop_keep=0.5, top_k=3, attack_class=attack_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cdd611a-1cd1-4b60-a4da-b7d83ede3229",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "68040cc0-4838-4ec0-8273-36e3a7788946",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Select the target dataset for the distillation procedure:\n",
     "# membership mask of shadow set 0's training split\n",
     "mem_label = train_keep[0]\n",
     "mem_data = np.where(mem_label==True)[0]  # indices of member samples\n",
     "train_num = mem_label.sum()  # number of member samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3b6c6b19-d10e-42b2-a450-c20d342c7944",
   "metadata": {},
   "outputs": [],
   "source": [
     "dist_num = 50  # number of distillation splits\n",
     "np.random.seed(seed)\n",
     "# For each member sample (column), keep it in int(prop_keep * dist_num) of the\n",
     "# dist_num splits: rank uniform noise per column; the smallest ranks are kept\n",
     "keep_matrix = np.random.uniform(0,1,size=(dist_num, train_num))\n",
     "order = keep_matrix.argsort(0)\n",
     "dist_keep = order < int(prop_keep * dist_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1cfbed3b-e970-4b18-919d-47f5f708bdde",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Features and labels restricted to the member samples\n",
     "x = X_data[mem_label]\n",
     "y = Y_data[mem_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c3c99aba-ff2c-49a3-b80f-745b6a6f77ae",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
     "# # Load the outputs of the target data under all distilled models\n",
     "# weight_part = \"{}_{}_epoch{}_dist0_model\".format(data_name, model, epochs)\n",
     "# conf_data_train, label_data_train, _ = load_score_data_all(x, y, weight_dir, dist_num, data_name, model, weight_part, model_transform, batch_size, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "48f95b18-ac13-4ecc-be30-4abbd2cd44c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conf_in = []\n",
    "# conf_out = []\n",
    "# for i in range(conf_data_train.shape[1]):\n",
    "#     conf_in.append((conf_data_train[dist_keep[:,i],i]))\n",
    "#     conf_out.append((conf_data_train[~dist_keep[:,i],i]))\n",
    "# conf_in = np.array(conf_in)\n",
    "# conf_out = np.array(conf_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2fff78bf-465a-4653-874b-75d76228444d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conf_in_mean = np.median(conf_in, 1)\n",
    "# conf_out_mean = np.median(conf_out, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "70a52e9f-fe64-4da3-9039-ebfb9a38a847",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.save('../outputs_save/CIFAR100_resnet_conf_in_mean.npy', conf_in_mean)\n",
    "# np.save('../outputs_save/CIFAR100_resnet_conf_out_mean.npy', conf_out_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "65ccbc82-a0af-42db-8c49-0ffabe7ebbc5",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Load cached median in/out confidences (produced by the commented-out cells in this notebook)\n",
     "conf_in_mean = np.load('../outputs_save/CIFAR100_resnet_conf_in_mean.npy')\n",
     "conf_out_mean = np.load('../outputs_save/CIFAR100_resnet_conf_out_mean.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1cc0e7fc-6f95-4d49-b23f-b42f193c79d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pri_risk_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a2d58cef-fee7-48d2-b671-f875b5f66204",
   "metadata": {},
   "outputs": [],
   "source": [
     "# One-hot encode (100 classes, float64) the labels of the member samples\n",
     "y_onehot = np.eye(100, dtype=np.float64)[Y_data[mem_label]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "19a7db60-3c3c-453f-a36d-c66a08989e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_entropy(y_pred, y_true):\n",
    "    ce = -y_true*np.log(y_pred+1e-30)\n",
    "    ce = np.sum(ce, axis=0)\n",
    "    return ce\n",
    "\n",
    "def cal_risk(conf_in, conf_out, y_true):\n",
    "    loss_in = cross_entropy(conf_in, y_true)\n",
    "    loss_out = cross_entropy(conf_out, y_true)\n",
    "    risk = loss_out - loss_in\n",
    "    if risk < 0:\n",
    "        risk = -risk\n",
    "    return risk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4ddb95e1-69e4-4447-abca-b965a0f63e41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
     "# Load the CIFAR-100 test split; preprocessing is applied later by CustomDataset\n",
     "test_dataset = datasets.cifar.CIFAR100(root='../datasets/cifar100', train=False, transform=None, download=True)\n",
     "x_test_data = test_dataset.data\n",
     "y_test_data = np.array(test_dataset.targets)\n",
     "test_data = CustomDataset(x_test_data, y_test_data, model_transform_test)\n",
     "test_dataloader = DataLoader(test_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3f46af4b-dc23-45fc-a1e8-4d98fab647d7",
   "metadata": {},
   "outputs": [],
   "source": [
     "def label_fix(risk_bound, k_init, y_onehot, conf_out_mean, pri_risk_all, mem_data):\n",
     "    \"\"\"Soften one-hot labels so each sample's estimated privacy risk stays under risk_bound.\n",
     "\n",
     "    Low-risk samples get a fixed mix k_init*onehot + (1-k_init)*conf_out_mean.\n",
     "    High-risk samples start from the same mix, then k is decreased in 0.01 steps\n",
     "    (down to 0.01) until the recomputed risk drops below risk_bound.\n",
     "    Returns a new array; y_onehot is not modified.\n",
     "    \"\"\"\n",
     "    y_soft = y_onehot.copy()\n",
     "    for i in range(y_onehot.shape[0]):\n",
     "        # precomputed risk of this member sample\n",
     "        risk = pri_risk_all[mem_data[i]]\n",
     "        if risk < risk_bound:\n",
     "            k = k_init\n",
     "            y_soft[i] = k*y_onehot[i] + (1-k)*conf_out_mean[i]\n",
     "        else:\n",
     "            k = k_init\n",
     "            soft_label = k*y_onehot[i] + (1-k)*conf_out_mean[i]\n",
     "            risk_thre = risk_bound\n",
     "            # re-estimate risk with the candidate soft label standing in for conf_in\n",
     "            risk = cal_risk(soft_label, conf_out_mean[i], y_onehot[i])\n",
     "            while(risk>risk_thre and k>0.01):\n",
     "                k -= 0.01\n",
     "                soft_label = k*y_onehot[i] + (1-k)*conf_out_mean[i]\n",
     "                risk = cal_risk(soft_label, conf_out_mean[i], y_onehot[i])\n",
     "            y_soft[i] = soft_label\n",
     "            # print(k)\n",
     "    return y_soft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "58235b68-1b89-4b27-87e1-ed4e77e3fbeb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 3.997699  [  128/25020]\n",
      "loss: 3.189255  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 8.1%\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 3.166546  [  128/25020]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ecpkn/.conda/envs/opacus/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:156: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
      "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.800324  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 14.6%\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 2.875652  [  128/25020]\n",
      "loss: 2.561235  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 20.0%\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 2.686830  [  128/25020]\n",
      "loss: 2.217433  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 26.2%\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 2.424104  [  128/25020]\n",
      "loss: 1.853917  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 32.4%\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 2.110180  [  128/25020]\n",
      "loss: 1.720955  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 38.9%\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 1.780058  [  128/25020]\n",
      "loss: 1.527670  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 44.4%\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 1.621081  [  128/25020]\n",
      "loss: 1.343534  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 48.3%\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 1.509223  [  128/25020]\n",
      "loss: 1.318691  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 51.7%\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 1.418452  [  128/25020]\n",
      "loss: 1.094885  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 54.8%\n",
      "Epoch 11\n",
      "-------------------------------\n",
      "loss: 1.207603  [  128/25020]\n",
      "loss: 1.125062  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 58.1%\n",
      "Epoch 12\n",
      "-------------------------------\n",
      "loss: 1.188752  [  128/25020]\n",
      "loss: 0.985988  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 59.9%\n",
      "Epoch 13\n",
      "-------------------------------\n",
      "loss: 1.028568  [  128/25020]\n",
      "loss: 1.048485  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 61.8%\n",
      "Epoch 14\n",
      "-------------------------------\n",
      "loss: 1.056795  [  128/25020]\n",
      "loss: 0.883031  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 63.5%\n",
      "Epoch 15\n",
      "-------------------------------\n",
      "loss: 0.904344  [  128/25020]\n",
      "loss: 0.870372  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 64.4%\n",
      "Epoch 16\n",
      "-------------------------------\n",
      "loss: 0.911018  [  128/25020]\n",
      "loss: 0.756592  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 66.2%\n",
      "Epoch 17\n",
      "-------------------------------\n",
      "loss: 0.866654  [  128/25020]\n",
      "loss: 0.849124  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 67.1%\n",
      "Epoch 18\n",
      "-------------------------------\n",
      "loss: 0.842540  [  128/25020]\n",
      "loss: 0.702474  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 69.0%\n",
      "Epoch 19\n",
      "-------------------------------\n",
      "loss: 0.755552  [  128/25020]\n",
      "loss: 0.710633  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 69.8%\n",
      "Epoch 20\n",
      "-------------------------------\n",
      "loss: 0.879607  [  128/25020]\n",
      "loss: 0.705066  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 70.2%\n",
      "Epoch 21\n",
      "-------------------------------\n",
      "loss: 0.804185  [  128/25020]\n",
      "loss: 0.728379  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 71.2%\n",
      "Epoch 22\n",
      "-------------------------------\n",
      "loss: 0.794391  [  128/25020]\n",
      "loss: 0.750568  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 71.8%\n",
      "Epoch 23\n",
      "-------------------------------\n",
      "loss: 0.711563  [  128/25020]\n",
      "loss: 0.581799  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 72.3%\n",
      "Epoch 24\n",
      "-------------------------------\n",
      "loss: 0.657245  [  128/25020]\n",
      "loss: 0.712244  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 73.1%\n",
      "Epoch 25\n",
      "-------------------------------\n",
      "loss: 0.669701  [  128/25020]\n",
      "loss: 0.585648  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 73.3%\n",
      "Epoch 26\n",
      "-------------------------------\n",
      "loss: 0.690468  [  128/25020]\n",
      "loss: 0.658038  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 73.9%\n",
      "Epoch 27\n",
      "-------------------------------\n",
      "loss: 0.696313  [  128/25020]\n",
      "loss: 0.603876  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 74.6%\n",
      "Epoch 28\n",
      "-------------------------------\n",
      "loss: 0.572655  [  128/25020]\n",
      "loss: 0.564942  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 75.2%\n",
      "Epoch 29\n",
      "-------------------------------\n",
      "loss: 0.626039  [  128/25020]\n",
      "loss: 0.624875  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 75.4%\n",
      "Epoch 30\n",
      "-------------------------------\n",
      "loss: 0.599989  [  128/25020]\n",
      "loss: 0.402072  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 83.4%\n",
      "Epoch 31\n",
      "-------------------------------\n",
      "loss: 0.443327  [  128/25020]\n",
      "loss: 0.336586  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 87.1%\n",
      "Epoch 32\n",
      "-------------------------------\n",
      "loss: 0.394907  [  128/25020]\n",
      "loss: 0.336581  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 88.5%\n",
      "Epoch 33\n",
      "-------------------------------\n",
      "loss: 0.341042  [  128/25020]\n",
      "loss: 0.286823  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 89.6%\n",
      "Epoch 34\n",
      "-------------------------------\n",
      "loss: 0.323192  [  128/25020]\n",
      "loss: 0.289934  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.6%\n",
      "Epoch 35\n",
      "-------------------------------\n",
      "loss: 0.329369  [  128/25020]\n",
      "loss: 0.284920  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.9%\n",
      "Epoch 36\n",
      "-------------------------------\n",
      "loss: 0.308556  [  128/25020]\n",
      "loss: 0.269010  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.5%\n",
      "Epoch 37\n",
      "-------------------------------\n",
      "loss: 0.305862  [  128/25020]\n",
      "loss: 0.250270  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.9%\n",
      "Epoch 38\n",
      "-------------------------------\n",
      "loss: 0.291214  [  128/25020]\n",
      "loss: 0.246759  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.1%\n",
      "Epoch 39\n",
      "-------------------------------\n",
      "loss: 0.298045  [  128/25020]\n",
      "loss: 0.255828  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.2%\n",
      "Epoch 40\n",
      "-------------------------------\n",
      "loss: 0.295407  [  128/25020]\n",
      "loss: 0.258234  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.7%\n",
      "Epoch 41\n",
      "-------------------------------\n",
      "loss: 0.289596  [  128/25020]\n",
      "loss: 0.249333  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.9%\n",
      "Epoch 42\n",
      "-------------------------------\n",
      "loss: 0.277687  [  128/25020]\n",
      "loss: 0.241781  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.2%\n",
      "Epoch 43\n",
      "-------------------------------\n",
      "loss: 0.276316  [  128/25020]\n",
      "loss: 0.237593  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.2%\n",
      "Epoch 44\n",
      "-------------------------------\n",
      "loss: 0.279254  [  128/25020]\n",
      "loss: 0.242057  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.4%\n",
      "Epoch 45\n",
      "-------------------------------\n",
      "loss: 0.277601  [  128/25020]\n",
      "loss: 0.229619  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 46\n",
      "-------------------------------\n",
      "loss: 0.271078  [  128/25020]\n",
      "loss: 0.235993  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 47\n",
      "-------------------------------\n",
      "loss: 0.275788  [  128/25020]\n",
      "loss: 0.232036  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 48\n",
      "-------------------------------\n",
      "loss: 0.268674  [  128/25020]\n",
      "loss: 0.234572  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 49\n",
      "-------------------------------\n",
      "loss: 0.268718  [  128/25020]\n",
      "loss: 0.230093  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 50\n",
      "-------------------------------\n",
      "loss: 0.271261  [  128/25020]\n",
      "loss: 0.229542  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.1%\n",
      "Epoch 51\n",
      "-------------------------------\n",
      "loss: 0.260078  [  128/25020]\n",
      "loss: 0.219378  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.4%\n",
      "Epoch 52\n",
      "-------------------------------\n",
      "loss: 0.266887  [  128/25020]\n",
      "loss: 0.219975  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.7%\n",
      "Epoch 53\n",
      "-------------------------------\n",
      "loss: 0.261775  [  128/25020]\n",
      "loss: 0.220018  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.9%\n",
      "Epoch 54\n",
      "-------------------------------\n",
      "loss: 0.259977  [  128/25020]\n",
      "loss: 0.218095  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.9%\n",
      "Epoch 55\n",
      "-------------------------------\n",
      "loss: 0.260759  [  128/25020]\n",
      "loss: 0.216239  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 56\n",
      "-------------------------------\n",
      "loss: 0.254671  [  128/25020]\n",
      "loss: 0.222533  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.8%\n",
      "Epoch 57\n",
      "-------------------------------\n",
      "loss: 0.253706  [  128/25020]\n",
      "loss: 0.218705  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 58\n",
      "-------------------------------\n",
      "loss: 0.251945  [  128/25020]\n",
      "loss: 0.218713  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.0%\n",
      "Epoch 59\n",
      "-------------------------------\n",
      "loss: 0.255752  [  128/25020]\n",
      "loss: 0.218542  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 60\n",
      "-------------------------------\n",
      "loss: 0.254538  [  128/25020]\n",
      "loss: 0.218878  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 61\n",
      "-------------------------------\n",
      "loss: 0.258459  [  128/25020]\n",
      "loss: 0.218168  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.0%\n",
      "Epoch 62\n",
      "-------------------------------\n",
      "loss: 0.259018  [  128/25020]\n",
      "loss: 0.219768  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 63\n",
      "-------------------------------\n",
      "loss: 0.254522  [  128/25020]\n",
      "loss: 0.219244  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.2%\n",
      "Epoch 64\n",
      "-------------------------------\n",
      "loss: 0.260035  [  128/25020]\n",
      "loss: 0.216876  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.2%\n",
      "Epoch 65\n",
      "-------------------------------\n",
      "loss: 0.252565  [  128/25020]\n",
      "loss: 0.217393  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.2%\n",
      "Epoch 66\n",
      "-------------------------------\n",
      "loss: 0.250500  [  128/25020]\n",
      "loss: 0.221911  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 67\n",
      "-------------------------------\n",
      "loss: 0.251473  [  128/25020]\n",
      "loss: 0.218321  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 68\n",
      "-------------------------------\n",
      "loss: 0.252163  [  128/25020]\n",
      "loss: 0.218437  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.3%\n",
      "Epoch 69\n",
      "-------------------------------\n",
      "loss: 0.254856  [  128/25020]\n",
      "loss: 0.216679  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 70\n",
      "-------------------------------\n",
      "loss: 0.251350  [  128/25020]\n",
      "loss: 0.218031  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 71\n",
      "-------------------------------\n",
      "loss: 0.249693  [  128/25020]\n",
      "loss: 0.213362  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 72\n",
      "-------------------------------\n",
      "loss: 0.250638  [  128/25020]\n",
      "loss: 0.215324  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 73\n",
      "-------------------------------\n",
      "loss: 0.249407  [  128/25020]\n",
      "loss: 0.220302  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 74\n",
      "-------------------------------\n",
      "loss: 0.250299  [  128/25020]\n",
      "loss: 0.216205  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 75\n",
      "-------------------------------\n",
      "loss: 0.251177  [  128/25020]\n",
      "loss: 0.215158  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 76\n",
      "-------------------------------\n",
      "loss: 0.249762  [  128/25020]\n",
      "loss: 0.212690  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 77\n",
      "-------------------------------\n",
      "loss: 0.251919  [  128/25020]\n",
      "loss: 0.215781  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 78\n",
      "-------------------------------\n",
      "loss: 0.250893  [  128/25020]\n",
      "loss: 0.212013  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 79\n",
      "-------------------------------\n",
      "loss: 0.253326  [  128/25020]\n",
      "loss: 0.214138  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 80\n",
      "-------------------------------\n",
      "loss: 0.249624  [  128/25020]\n",
      "loss: 0.212940  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 81\n",
      "-------------------------------\n",
      "loss: 0.251635  [  128/25020]\n",
      "loss: 0.215671  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 82\n",
      "-------------------------------\n",
      "loss: 0.249112  [  128/25020]\n",
      "loss: 0.213266  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 83\n",
      "-------------------------------\n",
      "loss: 0.250097  [  128/25020]\n",
      "loss: 0.215131  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 84\n",
      "-------------------------------\n",
      "loss: 0.246940  [  128/25020]\n",
      "loss: 0.211495  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 85\n",
      "-------------------------------\n",
      "loss: 0.250071  [  128/25020]\n",
      "loss: 0.212773  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.8%\n",
      "Epoch 86\n",
      "-------------------------------\n",
      "loss: 0.251580  [  128/25020]\n",
      "loss: 0.213583  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 87\n",
      "-------------------------------\n",
      "loss: 0.250632  [  128/25020]\n",
      "loss: 0.212109  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.7%\n",
      "Epoch 88\n",
      "-------------------------------\n",
      "loss: 0.252862  [  128/25020]\n",
      "loss: 0.211341  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 89\n",
      "-------------------------------\n",
      "loss: 0.248070  [  128/25020]\n",
      "loss: 0.213846  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 90\n",
      "-------------------------------\n",
      "loss: 0.251071  [  128/25020]\n",
      "loss: 0.213333  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.7%\n",
      "Epoch 91\n",
      "-------------------------------\n",
      "loss: 0.252565  [  128/25020]\n",
      "loss: 0.214182  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.8%\n",
      "Epoch 92\n",
      "-------------------------------\n",
      "loss: 0.250979  [  128/25020]\n",
      "loss: 0.212945  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 93\n",
      "-------------------------------\n",
      "loss: 0.247984  [  128/25020]\n",
      "loss: 0.213553  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 94\n",
      "-------------------------------\n",
      "loss: 0.250336  [  128/25020]\n",
      "loss: 0.213895  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.7%\n",
      "Epoch 95\n",
      "-------------------------------\n",
      "loss: 0.251689  [  128/25020]\n",
      "loss: 0.214380  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 96\n",
      "-------------------------------\n",
      "loss: 0.251703  [  128/25020]\n",
      "loss: 0.211273  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 97\n",
      "-------------------------------\n",
      "loss: 0.255484  [  128/25020]\n",
      "loss: 0.213247  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.7%\n",
      "Epoch 98\n",
      "-------------------------------\n",
      "loss: 0.251833  [  128/25020]\n",
      "loss: 0.215070  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 99\n",
      "-------------------------------\n",
      "loss: 0.249552  [  128/25020]\n",
      "loss: 0.215604  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 100\n",
      "-------------------------------\n",
      "loss: 0.248148  [  128/25020]\n",
      "loss: 0.215074  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.7%\n",
      "Done!\n",
      "Test Error: \n",
      " Accuracy: 62.5%, Avg loss: 1.423908 \n",
      "\n",
      "Test Error: \n",
      " Accuracy: 74.9%, Avg loss: 0.985340 \n",
      "\n",
      " Error: \n",
      " Accuracy: 74.9%  \n",
      "\n",
      "AUC value is: 0.7212114007752966\n",
      "Accuracy is: 0.62462\n",
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 3.565899  [  128/25020]\n",
      "loss: 2.899103  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 8.2%\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.822154  [  128/25020]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ecpkn/.conda/envs/opacus/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:156: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
      "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.487991  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 16.0%\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 2.607849  [  128/25020]\n",
      "loss: 2.156226  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 22.2%\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 2.349002  [  128/25020]\n",
      "loss: 1.907642  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 28.3%\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 2.021014  [  128/25020]\n",
      "loss: 1.670536  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 34.8%\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 1.757572  [  128/25020]\n",
      "loss: 1.454094  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 40.7%\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 1.624166  [  128/25020]\n",
      "loss: 1.350689  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 45.5%\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 1.461268  [  128/25020]\n",
      "loss: 1.113628  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 49.7%\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 1.329399  [  128/25020]\n",
      "loss: 1.019732  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 53.1%\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 1.161304  [  128/25020]\n",
      "loss: 0.956625  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 56.0%\n",
      "Epoch 11\n",
      "-------------------------------\n",
      "loss: 1.138046  [  128/25020]\n",
      "loss: 0.849134  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 58.7%\n",
      "Epoch 12\n",
      "-------------------------------\n",
      "loss: 1.050860  [  128/25020]\n",
      "loss: 0.975296  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 60.1%\n",
      "Epoch 13\n",
      "-------------------------------\n",
      "loss: 0.895090  [  128/25020]\n",
      "loss: 0.749738  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 61.9%\n",
      "Epoch 14\n",
      "-------------------------------\n",
      "loss: 0.892697  [  128/25020]\n",
      "loss: 0.789022  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 62.6%\n",
      "Epoch 15\n",
      "-------------------------------\n",
      "loss: 0.843345  [  128/25020]\n",
      "loss: 0.728638  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 63.9%\n",
      "Epoch 16\n",
      "-------------------------------\n",
      "loss: 0.819339  [  128/25020]\n",
      "loss: 0.722585  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 65.0%\n",
      "Epoch 17\n",
      "-------------------------------\n",
      "loss: 0.781183  [  128/25020]\n",
      "loss: 0.694588  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 66.7%\n",
      "Epoch 18\n",
      "-------------------------------\n",
      "loss: 0.829712  [  128/25020]\n",
      "loss: 0.722837  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 66.6%\n",
      "Epoch 19\n",
      "-------------------------------\n",
      "loss: 0.741421  [  128/25020]\n",
      "loss: 0.658768  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 67.9%\n",
      "Epoch 20\n",
      "-------------------------------\n",
      "loss: 0.706630  [  128/25020]\n",
      "loss: 0.641546  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 69.5%\n",
      "Epoch 21\n",
      "-------------------------------\n",
      "loss: 0.743554  [  128/25020]\n",
      "loss: 0.629065  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 69.1%\n",
      "Epoch 22\n",
      "-------------------------------\n",
      "loss: 0.675293  [  128/25020]\n",
      "loss: 0.643646  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 70.1%\n",
      "Epoch 23\n",
      "-------------------------------\n",
      "loss: 0.690719  [  128/25020]\n",
      "loss: 0.584603  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 70.9%\n",
      "Epoch 24\n",
      "-------------------------------\n",
      "loss: 0.616975  [  128/25020]\n",
      "loss: 0.737543  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 71.6%\n",
      "Epoch 25\n",
      "-------------------------------\n",
      "loss: 0.627166  [  128/25020]\n",
      "loss: 0.575991  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 72.0%\n",
      "Epoch 26\n",
      "-------------------------------\n",
      "loss: 0.603640  [  128/25020]\n",
      "loss: 0.590642  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 72.7%\n",
      "Epoch 27\n",
      "-------------------------------\n",
      "loss: 0.664175  [  128/25020]\n",
      "loss: 0.552540  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 72.3%\n",
      "Epoch 28\n",
      "-------------------------------\n",
      "loss: 0.553312  [  128/25020]\n",
      "loss: 0.557454  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 73.6%\n",
      "Epoch 29\n",
      "-------------------------------\n",
      "loss: 0.625422  [  128/25020]\n",
      "loss: 0.588274  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 73.4%\n",
      "Epoch 30\n",
      "-------------------------------\n",
      "loss: 0.622057  [  128/25020]\n",
      "loss: 0.483343  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 80.2%\n",
      "Epoch 31\n",
      "-------------------------------\n",
      "loss: 0.414649  [  128/25020]\n",
      "loss: 0.365753  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 84.0%\n",
      "Epoch 32\n",
      "-------------------------------\n",
      "loss: 0.407606  [  128/25020]\n",
      "loss: 0.342852  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 86.0%\n",
      "Epoch 33\n",
      "-------------------------------\n",
      "loss: 0.369736  [  128/25020]\n",
      "loss: 0.329413  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 86.7%\n",
      "Epoch 34\n",
      "-------------------------------\n",
      "loss: 0.355405  [  128/25020]\n",
      "loss: 0.315264  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 87.5%\n",
      "Epoch 35\n",
      "-------------------------------\n",
      "loss: 0.365210  [  128/25020]\n",
      "loss: 0.332635  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 88.1%\n",
      "Epoch 36\n",
      "-------------------------------\n",
      "loss: 0.344763  [  128/25020]\n",
      "loss: 0.308207  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 88.7%\n",
      "Epoch 37\n",
      "-------------------------------\n",
      "loss: 0.357630  [  128/25020]\n",
      "loss: 0.303442  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 89.1%\n",
      "Epoch 38\n",
      "-------------------------------\n",
      "loss: 0.339031  [  128/25020]\n",
      "loss: 0.311355  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 89.3%\n",
      "Epoch 39\n",
      "-------------------------------\n",
      "loss: 0.341977  [  128/25020]\n",
      "loss: 0.290907  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 89.7%\n",
      "Epoch 40\n",
      "-------------------------------\n",
      "loss: 0.339954  [  128/25020]\n",
      "loss: 0.287978  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.1%\n",
      "Epoch 41\n",
      "-------------------------------\n",
      "loss: 0.335192  [  128/25020]\n",
      "loss: 0.288090  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.2%\n",
      "Epoch 42\n",
      "-------------------------------\n",
      "loss: 0.326171  [  128/25020]\n",
      "loss: 0.279052  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.7%\n",
      "Epoch 43\n",
      "-------------------------------\n",
      "loss: 0.322207  [  128/25020]\n",
      "loss: 0.282898  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.8%\n",
      "Epoch 44\n",
      "-------------------------------\n",
      "loss: 0.327372  [  128/25020]\n",
      "loss: 0.281448  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.9%\n",
      "Epoch 45\n",
      "-------------------------------\n",
      "loss: 0.325519  [  128/25020]\n",
      "loss: 0.272803  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.2%\n",
      "Epoch 46\n",
      "-------------------------------\n",
      "loss: 0.321224  [  128/25020]\n",
      "loss: 0.283670  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.6%\n",
      "Epoch 47\n",
      "-------------------------------\n",
      "loss: 0.316801  [  128/25020]\n",
      "loss: 0.277828  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.5%\n",
      "Epoch 48\n",
      "-------------------------------\n",
      "loss: 0.321695  [  128/25020]\n",
      "loss: 0.280237  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.7%\n",
      "Epoch 49\n",
      "-------------------------------\n",
      "loss: 0.323362  [  128/25020]\n",
      "loss: 0.276934  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.8%\n",
      "Epoch 50\n",
      "-------------------------------\n",
      "loss: 0.317180  [  128/25020]\n",
      "loss: 0.269548  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.2%\n",
      "Epoch 51\n",
      "-------------------------------\n",
      "loss: 0.314052  [  128/25020]\n",
      "loss: 0.268729  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.5%\n",
      "Epoch 52\n",
      "-------------------------------\n",
      "loss: 0.315467  [  128/25020]\n",
      "loss: 0.266849  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.7%\n",
      "Epoch 53\n",
      "-------------------------------\n",
      "loss: 0.313051  [  128/25020]\n",
      "loss: 0.269168  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.9%\n",
      "Epoch 54\n",
      "-------------------------------\n",
      "loss: 0.314576  [  128/25020]\n",
      "loss: 0.265820  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.9%\n",
      "Epoch 55\n",
      "-------------------------------\n",
      "loss: 0.314298  [  128/25020]\n",
      "loss: 0.264366  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.9%\n",
      "Epoch 56\n",
      "-------------------------------\n",
      "loss: 0.308693  [  128/25020]\n",
      "loss: 0.262931  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.8%\n",
      "Epoch 57\n",
      "-------------------------------\n",
      "loss: 0.310633  [  128/25020]\n",
      "loss: 0.265934  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.1%\n",
      "Epoch 58\n",
      "-------------------------------\n",
      "loss: 0.314798  [  128/25020]\n",
      "loss: 0.261048  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.0%\n",
      "Epoch 59\n",
      "-------------------------------\n",
      "loss: 0.308273  [  128/25020]\n",
      "loss: 0.267764  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.1%\n",
      "Epoch 60\n",
      "-------------------------------\n",
      "loss: 0.309459  [  128/25020]\n",
      "loss: 0.259337  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.1%\n",
      "Epoch 61\n",
      "-------------------------------\n",
      "loss: 0.305984  [  128/25020]\n",
      "loss: 0.266292  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.4%\n",
      "Epoch 62\n",
      "-------------------------------\n",
      "loss: 0.311051  [  128/25020]\n",
      "loss: 0.263082  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.3%\n",
      "Epoch 63\n",
      "-------------------------------\n",
      "loss: 0.310598  [  128/25020]\n",
      "loss: 0.261908  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.3%\n",
      "Epoch 64\n",
      "-------------------------------\n",
      "loss: 0.309681  [  128/25020]\n",
      "loss: 0.262102  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.3%\n",
      "Epoch 65\n",
      "-------------------------------\n",
      "loss: 0.313315  [  128/25020]\n",
      "loss: 0.260412  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.4%\n",
      "Epoch 66\n",
      "-------------------------------\n",
      "loss: 0.307872  [  128/25020]\n",
      "loss: 0.262215  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.5%\n",
      "Epoch 67\n",
      "-------------------------------\n",
      "loss: 0.308468  [  128/25020]\n",
      "loss: 0.261821  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.5%\n",
      "Epoch 68\n",
      "-------------------------------\n",
      "loss: 0.309200  [  128/25020]\n",
      "loss: 0.264939  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.4%\n",
      "Epoch 69\n",
      "-------------------------------\n",
      "loss: 0.306524  [  128/25020]\n",
      "loss: 0.262478  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 70\n",
      "-------------------------------\n",
      "loss: 0.309500  [  128/25020]\n",
      "loss: 0.263681  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 71\n",
      "-------------------------------\n",
      "loss: 0.308376  [  128/25020]\n",
      "loss: 0.260098  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 72\n",
      "-------------------------------\n",
      "loss: 0.301801  [  128/25020]\n",
      "loss: 0.259034  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.5%\n",
      "Epoch 73\n",
      "-------------------------------\n",
      "loss: 0.306081  [  128/25020]\n",
      "loss: 0.258821  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 74\n",
      "-------------------------------\n",
      "loss: 0.305132  [  128/25020]\n",
      "loss: 0.259720  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 75\n",
      "-------------------------------\n",
      "loss: 0.304409  [  128/25020]\n",
      "loss: 0.263092  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 76\n",
      "-------------------------------\n",
      "loss: 0.306587  [  128/25020]\n",
      "loss: 0.259811  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 77\n",
      "-------------------------------\n",
      "loss: 0.303776  [  128/25020]\n",
      "loss: 0.258964  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 78\n",
      "-------------------------------\n",
      "loss: 0.307641  [  128/25020]\n",
      "loss: 0.262076  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 79\n",
      "-------------------------------\n",
      "loss: 0.305623  [  128/25020]\n",
      "loss: 0.260711  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 80\n",
      "-------------------------------\n",
      "loss: 0.303641  [  128/25020]\n",
      "loss: 0.258346  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 81\n",
      "-------------------------------\n",
      "loss: 0.303476  [  128/25020]\n",
      "loss: 0.261645  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 82\n",
      "-------------------------------\n",
      "loss: 0.305751  [  128/25020]\n",
      "loss: 0.263554  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 83\n",
      "-------------------------------\n",
      "loss: 0.305460  [  128/25020]\n",
      "loss: 0.260289  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 84\n",
      "-------------------------------\n",
      "loss: 0.303663  [  128/25020]\n",
      "loss: 0.259236  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 85\n",
      "-------------------------------\n",
      "loss: 0.304689  [  128/25020]\n",
      "loss: 0.260234  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 86\n",
      "-------------------------------\n",
      "loss: 0.304870  [  128/25020]\n",
      "loss: 0.265054  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 87\n",
      "-------------------------------\n",
      "loss: 0.303217  [  128/25020]\n",
      "loss: 0.258846  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.2%\n",
      "Epoch 88\n",
      "-------------------------------\n",
      "loss: 0.302492  [  128/25020]\n",
      "loss: 0.259181  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 89\n",
      "-------------------------------\n",
      "loss: 0.305036  [  128/25020]\n",
      "loss: 0.260511  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 90\n",
      "-------------------------------\n",
      "loss: 0.306741  [  128/25020]\n",
      "loss: 0.262163  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 91\n",
      "-------------------------------\n",
      "loss: 0.307256  [  128/25020]\n",
      "loss: 0.261619  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 92\n",
      "-------------------------------\n",
      "loss: 0.303047  [  128/25020]\n",
      "loss: 0.261412  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.8%\n",
      "Epoch 93\n",
      "-------------------------------\n",
      "loss: 0.307982  [  128/25020]\n",
      "loss: 0.259725  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 94\n",
      "-------------------------------\n",
      "loss: 0.303173  [  128/25020]\n",
      "loss: 0.261234  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 95\n",
      "-------------------------------\n",
      "loss: 0.303586  [  128/25020]\n",
      "loss: 0.259123  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 96\n",
      "-------------------------------\n",
      "loss: 0.307637  [  128/25020]\n",
      "loss: 0.258800  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Epoch 97\n",
      "-------------------------------\n",
      "loss: 0.302959  [  128/25020]\n",
      "loss: 0.257961  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 98\n",
      "-------------------------------\n",
      "loss: 0.302451  [  128/25020]\n",
      "loss: 0.262680  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 99\n",
      "-------------------------------\n",
      "loss: 0.302863  [  128/25020]\n",
      "loss: 0.260069  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 100\n",
      "-------------------------------\n",
      "loss: 0.303417  [  128/25020]\n",
      "loss: 0.260142  [12928/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.0%\n",
      "Done!\n",
      "Test Error: \n",
      " Accuracy: 60.4%, Avg loss: 1.518693 \n",
      "\n",
      "Test Error: \n",
      " Accuracy: 69.8%, Avg loss: 1.166083 \n",
      "\n",
      " Error: \n",
      " Accuracy: 69.8%  \n",
      "\n",
      "AUC value is: 0.663031648340255\n",
      "Accuracy is: 0.5904\n"
     ]
    }
   ],
   "source": [
    "# Train a target model on risk-bounded (self-distilled) labels for each\n",
    "# (risk_bound, k_init) setting, then run membership-inference attacks\n",
    "# (baseline and LiRA; shadow-model attack currently disabled) and record\n",
    "# both overall accuracy and accuracy on the top-risk examples.\n",
    "# risk_bound_list = [0.01, 0.05, 0.1, 0.5, 1, 2, 4, 6, 10]\n",
    "risk_bound_list = [1]\n",
    "# k_init_list = [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1]\n",
    "k_init_list = [0.5, 0]\n",
    "model_test_correct = []\n",
    "average_lira_attack = []\n",
    "average_base_attack = []\n",
    "average_shadow_attack = []\n",
    "risk_base_attack = []\n",
    "risk_lira_attack = []\n",
    "risk_shadow_attack = []\n",
    "top_risk = 500\n",
    "for risk_bound in risk_bound_list: # iterate over privacy-risk bounds\n",
    "    model_test_correct_t = []\n",
    "    average_lira_attack_t = []\n",
    "    average_base_attack_t = []\n",
    "    average_shadow_attack_t = []\n",
    "    risk_base_attack_t = []\n",
    "    risk_lira_attack_t = []\n",
    "    risk_shadow_attack_t = []\n",
    "    for k_init in k_init_list: # iterate over initial k values\n",
    "        y_soft = label_fix(risk_bound, k_init, y_onehot, conf_out_mean, pri_risk_all, mem_data)\n",
    "        x = X_data[mem_label]\n",
    "        y = y_soft\n",
    "        # y = y_onehot\n",
    "        train_data = CustomDataset(x, y, model_transform)\n",
    "        train_dataloader = DataLoader(train_data, batch_size=batch_size)\n",
    "        TargetModel = globals()['create_{}_model'.format(model)](100)\n",
    "        TargetModel.to(device)\n",
    "        # loss_fn = nn.KLDivLoss()  # use KL-divergence loss\n",
    "        loss_fn = nn.CrossEntropyLoss()\n",
    "        # optimizer = torch.optim.Adam(TargetModel.parameters(), lr=LEARNING_RATE)\n",
    "        optimizer = torch.optim.SGD(TargetModel.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)\n",
    "        train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 70, 90], gamma=0.2)\n",
    "        for t in range(100):\n",
    "            print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "            # train_softlabel(train_dataloader, TargetModel, loss_fn, optimizer, device)\n",
    "            train_onehot(train_dataloader, TargetModel, loss_fn, optimizer, device)\n",
    "            # Step the LR scheduler once per epoch AFTER the optimizer updates;\n",
    "            # passing an epoch argument to step() is deprecated in modern PyTorch.\n",
    "            train_scheduler.step()\n",
    "        print(\"Done!\")\n",
    "        # Target model trained; evaluate its test accuracy.\n",
    "        loss_fn = nn.CrossEntropyLoss()\n",
    "        correct = evaluate(test_dataloader, TargetModel, loss_fn, device)\n",
    "        model_test_correct_t.append(correct)\n",
    "        # Run the baseline attack (overall, then on the top-risk subset).\n",
    "        pred_result = base_attack(all_dataloader, TargetModel, loss_fn, device)\n",
    "        accuracy = metrics.accuracy_score(train_keep[0], pred_result)\n",
    "        average_base_attack_t.append(accuracy)\n",
    "        pred_clip = pred_result[pri_risk_rank[:top_risk]]\n",
    "        mem_clip = train_keep[0][pri_risk_rank[:top_risk]]\n",
    "        accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "        risk_base_attack_t.append(accuracy)\n",
    "        # Run the likelihood-ratio (LiRA) attack.\n",
    "        _, score = get_score_from_model(all_dataloader, TargetModel, device)\n",
    "        pred_result = LIRA_attack(train_keep, score_all, score, train_keep[0])\n",
    "        accuracy = evaluate_ROC(pred_result, train_keep[0], threshold=0)\n",
    "        average_lira_attack_t.append(accuracy)\n",
    "        pred_clip = pred_result[pri_risk_rank[:top_risk]]\n",
    "        mem_clip = train_keep[0][pri_risk_rank[:top_risk]]\n",
    "        pred_clip = pred_clip > 0\n",
    "        accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "        risk_lira_attack_t.append(accuracy)\n",
    "        # Run the shadow-model attack (currently disabled).\n",
    "\n",
    "        # # Extract the dataset's confidence outputs from the model\n",
    "        # targetX, _ = get_model_pred(all_dataloader, TargetModel, device)\n",
    "        # targetX = targetX.detach().cpu().numpy()\n",
    "        # targetX = targetX.astype(np.float32)\n",
    "        \n",
    "        # top_k = 3\n",
    "        # if top_k:\n",
    "        #     # Use only the top-3 values of the probability vector\n",
    "        #     targetX, _ = get_top_k_conf(top_k, targetX, targetX)\n",
    "\n",
    "        # shadow_attack_data = CustomDataset(targetX, train_keep[0], attack_transform)\n",
    "        # shadow_attack_dataloader = DataLoader(shadow_attack_data, batch_size=batch_size, shuffle=False)\n",
    "        # attack_test_scores, attack_test_mem = get_attack_pred(shadow_attack_dataloader, attack_model, device)\n",
    "        # attack_test_scores, attack_test_mem = attack_test_scores.detach().cpu().numpy(), attack_test_mem.detach().cpu().numpy()\n",
    "        \n",
    "        # pred_clip = attack_test_scores[pri_risk_rank[:top_risk]]\n",
    "        # mem_clip = train_keep[0][pri_risk_rank[:top_risk]]\n",
    "        # accuracy = evaluate_ROC(pred_clip, mem_clip)\n",
    "        # risk_shadow_attack_t.append(accuracy)\n",
    "        \n",
    "        # accuracy = evaluate_ROC(attack_test_scores, attack_test_mem)\n",
    "        # average_shadow_attack_t.append(accuracy)\n",
    "    \n",
    "    \n",
    "    model_test_correct.append(model_test_correct_t)\n",
    "    average_lira_attack.append(average_lira_attack_t)\n",
    "    average_base_attack.append(average_base_attack_t)\n",
    "    average_shadow_attack.append(average_shadow_attack_t)\n",
    "    risk_base_attack.append(risk_base_attack_t)\n",
    "    risk_lira_attack.append(risk_lira_attack_t)\n",
    "    risk_shadow_attack.append(risk_shadow_attack_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1d4fc723-0472-4f9f-8405-df67c9548193",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.6248, 0.6035]]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_test_correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b3adfb5d-5648-4027-9def-58f64132744e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.62462, 0.5904]]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_lira_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d28354fd-272d-4aa8-bea3-ed8e84014f95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[]]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_shadow_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d67ff7ba-810a-4e12-8738-59de800de44a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.668, 0.666]]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "risk_base_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f5f89dc8-5005-4533-8a86-e0a0f6bad75d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.666, 0.654]]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "risk_lira_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5d424ecd-a1b3-4202-bade-e2d6b6cd8d02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.62462, 0.5904]]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_lira_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "000aa423-6565-431c-ae30-2dfa9dab1349",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[]]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "risk_shadow_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a96c6909-ab59-4bca-8e1f-549e76c90233",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.59816, 0.5713]]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_base_attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7dbb816-defd-481c-ab96-d2895a0068eb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "6b396699-9f95-44f0-94b0-4c9237096729",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem_clip = train_keep[0][pri_risk_rank[:1000]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "9ff458f1-f1b8-486d-89cf-201b527eae01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "491"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mem_clip.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "596c23dd-ee16-4b3f-aec6-9f4f8cfd60b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-train the sub-models (reference/shadow models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "92ffd6c3-3517-459a-aff3-bf28baa32f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Kept for reference; disabled.) One-off loop that trained one reference\n",
    "# model per dist_keep split and saved each state_dict under weight_dir.\n",
    "# for i in range(dist_num):\n",
    "#     dis_train = mem_data[dist_keep[i]]\n",
    "#     x = X_data[dis_train]\n",
    "#     y = Y_data[dis_train]\n",
    "#     train_data = CustomDataset(x, y, model_transform)\n",
    "#     train_dataloader = DataLoader(train_data, batch_size=batch_size)\n",
    "#     ReferenceModel = globals()['create_{}_model'.format(model)](y.max()+1, data_name)\n",
    "#     ReferenceModel.to(device)\n",
    "#     loss_fn = nn.CrossEntropyLoss()\n",
    "#     optimizer = torch.optim.Adam(ReferenceModel.parameters(), lr=LEARNING_RATE)\n",
    "#     for t in range(epochs):\n",
    "#         print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "#         train(train_dataloader, ReferenceModel, loss_fn, optimizer, device)\n",
    "#     print(\"Done!\")\n",
    "#     weight_path = os.path.join(weight_dir, \"{}_{}_epoch{}_dist0_model{}.pth\".format(data_name, model, epochs, i))\n",
    "#     torch.save(ReferenceModel.state_dict(), weight_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "deb8edb3-83b7-42a6-ad54-c75a2ddd8b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Self-distill the original labels:\n",
    "# read the confidence outputs for all training examples and save them.\n",
    "# "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "812def3e-b726-4a14-b532-0d0b325f08a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the member (training-split) inputs and their original hard labels.\n",
    "x = X_data[mem_label]\n",
    "y = Y_data[mem_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "84ac3032-87fd-44ce-a5c7-cc920e16bf1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(25020,)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d4e6f7f6-ff0e-4b4c-bc84-5f46a5e15c65",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Error: \n",
      " Accuracy: 83.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.0%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.5%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.0%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.7%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 80.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 80.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 80.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 85.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.6%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 79.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.7%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.6%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.0%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 80.9%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.0%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 79.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 80.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.6%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.1%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.7%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.8%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 82.7%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 83.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 81.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 84.7%  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load per-model scores for (x, y) from the saved reference-model weights.\n",
    "# NOTE(review): presumably expects dist_num weight files named\n",
    "# \"<weight_part><i>.pth\" in weight_dir (see the disabled training cell above) —\n",
    "# confirm against load_score_data_all. Output below is (dist_num, N, classes).\n",
    "weight_part = \"{}_{}_epoch{}_dist0_model\".format(data_name, model, epochs)\n",
    "conf_data_train, label_data_train, _ = load_score_data_all(x, y, weight_dir, dist_num, data_name, model, weight_part, model_transform, batch_size, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9edaf5-abe7-4dd6-811a-00685073ebe0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "30b3ba83-3cdf-46ff-b440-405e10aa9d33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 25020, 10)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf_data_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7af84727-1418-4384-89fb-ee33790e72c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every example i, split the per-model confidence vectors into the\n",
    "# IN group (models whose split kept example i, per dist_keep) and the\n",
    "# OUT group (models whose split excluded it).\n",
    "num_examples = conf_data_train.shape[1]\n",
    "conf_in = np.array([conf_data_train[dist_keep[:, i], i] for i in range(num_examples)])\n",
    "conf_out = np.array([conf_data_train[~dist_keep[:, i], i] for i in range(num_examples)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "117e80ea-6e76-4656-b679-8d49afca7ca2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(25020, 25, 10)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf_in.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "361b4887-2694-45ff-9291-f8e62b3c8646",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate IN/OUT confidences per example across models.\n",
    "# NOTE(review): despite the `_mean` suffix, these are medians (np.median).\n",
    "conf_in_mean = np.median(conf_in, 1)\n",
    "conf_out_mean = np.median(conf_out, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c4afc608-ba50-41b4-913e-716be131318c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(25020, 10)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf_out_mean.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "368ffd19-1f56-48dd-a074-0adc0c86a88d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "51a1b4d1-07e7-4a1a-b2db-2b34ffc665f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: train the model with non-member (OUT) labels\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69145756-b20d-40d3-a478-44881e286bea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "058d9770-37ed-4818-88a6-ba976886e00d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "4ea41d5e-74a4-4a90-a18d-e7d34c304b70",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 1.801403  [   64/25020]\n",
      "loss: 0.824510  [ 6464/25020]\n",
      "loss: 0.918130  [12864/25020]\n",
      "loss: 0.528519  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 59.5%\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 0.532350  [   64/25020]\n",
      "loss: 0.364616  [ 6464/25020]\n",
      "loss: 0.545576  [12864/25020]\n",
      "loss: 0.330572  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 75.4%\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 0.324737  [   64/25020]\n",
      "loss: 0.285437  [ 6464/25020]\n",
      "loss: 0.454724  [12864/25020]\n",
      "loss: 0.233015  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 80.8%\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 0.298071  [   64/25020]\n",
      "loss: 0.225670  [ 6464/25020]\n",
      "loss: 0.349112  [12864/25020]\n",
      "loss: 0.175004  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 83.8%\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 0.240583  [   64/25020]\n",
      "loss: 0.211505  [ 6464/25020]\n",
      "loss: 0.263312  [12864/25020]\n",
      "loss: 0.149625  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 85.5%\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 0.184063  [   64/25020]\n",
      "loss: 0.201455  [ 6464/25020]\n",
      "loss: 0.277144  [12864/25020]\n",
      "loss: 0.130788  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 87.3%\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 0.152649  [   64/25020]\n",
      "loss: 0.146956  [ 6464/25020]\n",
      "loss: 0.209676  [12864/25020]\n",
      "loss: 0.167928  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 88.3%\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 0.169000  [   64/25020]\n",
      "loss: 0.169397  [ 6464/25020]\n",
      "loss: 0.202958  [12864/25020]\n",
      "loss: 0.143935  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 88.9%\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 0.152008  [   64/25020]\n",
      "loss: 0.134156  [ 6464/25020]\n",
      "loss: 0.247687  [12864/25020]\n",
      "loss: 0.108098  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 89.7%\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 0.115682  [   64/25020]\n",
      "loss: 0.167386  [ 6464/25020]\n",
      "loss: 0.211232  [12864/25020]\n",
      "loss: 0.104994  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 90.9%\n",
      "Epoch 11\n",
      "-------------------------------\n",
      "loss: 0.132905  [   64/25020]\n",
      "loss: 0.134737  [ 6464/25020]\n",
      "loss: 0.170954  [12864/25020]\n",
      "loss: 0.091956  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 91.5%\n",
      "Epoch 12\n",
      "-------------------------------\n",
      "loss: 0.080937  [   64/25020]\n",
      "loss: 0.137520  [ 6464/25020]\n",
      "loss: 0.144292  [12864/25020]\n",
      "loss: 0.102301  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.0%\n",
      "Epoch 13\n",
      "-------------------------------\n",
      "loss: 0.083459  [   64/25020]\n",
      "loss: 0.125804  [ 6464/25020]\n",
      "loss: 0.166214  [12864/25020]\n",
      "loss: 0.111353  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.1%\n",
      "Epoch 14\n",
      "-------------------------------\n",
      "loss: 0.086955  [   64/25020]\n",
      "loss: 0.109707  [ 6464/25020]\n",
      "loss: 0.190748  [12864/25020]\n",
      "loss: 0.120527  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.6%\n",
      "Epoch 15\n",
      "-------------------------------\n",
      "loss: 0.094356  [   64/25020]\n",
      "loss: 0.116078  [ 6464/25020]\n",
      "loss: 0.204835  [12864/25020]\n",
      "loss: 0.113678  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.0%\n",
      "Epoch 16\n",
      "-------------------------------\n",
      "loss: 0.078904  [   64/25020]\n",
      "loss: 0.118839  [ 6464/25020]\n",
      "loss: 0.140360  [12864/25020]\n",
      "loss: 0.104875  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.7%\n",
      "Epoch 17\n",
      "-------------------------------\n",
      "loss: 0.078362  [   64/25020]\n",
      "loss: 0.100908  [ 6464/25020]\n",
      "loss: 0.148350  [12864/25020]\n",
      "loss: 0.081567  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 92.8%\n",
      "Epoch 18\n",
      "-------------------------------\n",
      "loss: 0.078305  [   64/25020]\n",
      "loss: 0.106324  [ 6464/25020]\n",
      "loss: 0.129427  [12864/25020]\n",
      "loss: 0.091890  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.0%\n",
      "Epoch 19\n",
      "-------------------------------\n",
      "loss: 0.110341  [   64/25020]\n",
      "loss: 0.132545  [ 6464/25020]\n",
      "loss: 0.131886  [12864/25020]\n",
      "loss: 0.104535  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.3%\n",
      "Epoch 20\n",
      "-------------------------------\n",
      "loss: 0.072668  [   64/25020]\n",
      "loss: 0.130856  [ 6464/25020]\n",
      "loss: 0.150544  [12864/25020]\n",
      "loss: 0.083680  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.4%\n",
      "Epoch 21\n",
      "-------------------------------\n",
      "loss: 0.074430  [   64/25020]\n",
      "loss: 0.104937  [ 6464/25020]\n",
      "loss: 0.159583  [12864/25020]\n",
      "loss: 0.087496  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.6%\n",
      "Epoch 22\n",
      "-------------------------------\n",
      "loss: 0.084780  [   64/25020]\n",
      "loss: 0.100780  [ 6464/25020]\n",
      "loss: 0.172515  [12864/25020]\n",
      "loss: 0.089151  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.7%\n",
      "Epoch 23\n",
      "-------------------------------\n",
      "loss: 0.071839  [   64/25020]\n",
      "loss: 0.094323  [ 6464/25020]\n",
      "loss: 0.138125  [12864/25020]\n",
      "loss: 0.081321  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.2%\n",
      "Epoch 24\n",
      "-------------------------------\n",
      "loss: 0.112051  [   64/25020]\n",
      "loss: 0.099713  [ 6464/25020]\n",
      "loss: 0.151379  [12864/25020]\n",
      "loss: 0.075381  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.1%\n",
      "Epoch 25\n",
      "-------------------------------\n",
      "loss: 0.070716  [   64/25020]\n",
      "loss: 0.088435  [ 6464/25020]\n",
      "loss: 0.148643  [12864/25020]\n",
      "loss: 0.111562  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 93.9%\n",
      "Epoch 26\n",
      "-------------------------------\n",
      "loss: 0.085514  [   64/25020]\n",
      "loss: 0.123032  [ 6464/25020]\n",
      "loss: 0.134057  [12864/25020]\n",
      "loss: 0.084148  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.4%\n",
      "Epoch 27\n",
      "-------------------------------\n",
      "loss: 0.072390  [   64/25020]\n",
      "loss: 0.095546  [ 6464/25020]\n",
      "loss: 0.184528  [12864/25020]\n",
      "loss: 0.093758  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.1%\n",
      "Epoch 28\n",
      "-------------------------------\n",
      "loss: 0.085965  [   64/25020]\n",
      "loss: 0.110336  [ 6464/25020]\n",
      "loss: 0.156971  [12864/25020]\n",
      "loss: 0.090958  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.7%\n",
      "Epoch 29\n",
      "-------------------------------\n",
      "loss: 0.072112  [   64/25020]\n",
      "loss: 0.094834  [ 6464/25020]\n",
      "loss: 0.136261  [12864/25020]\n",
      "loss: 0.096949  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.7%\n",
      "Epoch 30\n",
      "-------------------------------\n",
      "loss: 0.069347  [   64/25020]\n",
      "loss: 0.124688  [ 6464/25020]\n",
      "loss: 0.131592  [12864/25020]\n",
      "loss: 0.091233  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.5%\n",
      "Epoch 31\n",
      "-------------------------------\n",
      "loss: 0.065841  [   64/25020]\n",
      "loss: 0.114560  [ 6464/25020]\n",
      "loss: 0.137989  [12864/25020]\n",
      "loss: 0.090316  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 94.9%\n",
      "Epoch 32\n",
      "-------------------------------\n",
      "loss: 0.065472  [   64/25020]\n",
      "loss: 0.087512  [ 6464/25020]\n",
      "loss: 0.113361  [12864/25020]\n",
      "loss: 0.076608  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.1%\n",
      "Epoch 33\n",
      "-------------------------------\n",
      "loss: 0.068155  [   64/25020]\n",
      "loss: 0.082314  [ 6464/25020]\n",
      "loss: 0.124632  [12864/25020]\n",
      "loss: 0.092203  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 34\n",
      "-------------------------------\n",
      "loss: 0.064199  [   64/25020]\n",
      "loss: 0.079064  [ 6464/25020]\n",
      "loss: 0.168298  [12864/25020]\n",
      "loss: 0.087096  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 35\n",
      "-------------------------------\n",
      "loss: 0.060920  [   64/25020]\n",
      "loss: 0.088517  [ 6464/25020]\n",
      "loss: 0.123027  [12864/25020]\n",
      "loss: 0.088884  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.4%\n",
      "Epoch 36\n",
      "-------------------------------\n",
      "loss: 0.056989  [   64/25020]\n",
      "loss: 0.097483  [ 6464/25020]\n",
      "loss: 0.140635  [12864/25020]\n",
      "loss: 0.082483  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 37\n",
      "-------------------------------\n",
      "loss: 0.060464  [   64/25020]\n",
      "loss: 0.102194  [ 6464/25020]\n",
      "loss: 0.157972  [12864/25020]\n",
      "loss: 0.093913  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 38\n",
      "-------------------------------\n",
      "loss: 0.060466  [   64/25020]\n",
      "loss: 0.084376  [ 6464/25020]\n",
      "loss: 0.166104  [12864/25020]\n",
      "loss: 0.082175  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 39\n",
      "-------------------------------\n",
      "loss: 0.064088  [   64/25020]\n",
      "loss: 0.091828  [ 6464/25020]\n",
      "loss: 0.111517  [12864/25020]\n",
      "loss: 0.073598  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 40\n",
      "-------------------------------\n",
      "loss: 0.056548  [   64/25020]\n",
      "loss: 0.092327  [ 6464/25020]\n",
      "loss: 0.126877  [12864/25020]\n",
      "loss: 0.074704  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.5%\n",
      "Epoch 41\n",
      "-------------------------------\n",
      "loss: 0.063155  [   64/25020]\n",
      "loss: 0.081421  [ 6464/25020]\n",
      "loss: 0.132282  [12864/25020]\n",
      "loss: 0.070912  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.6%\n",
      "Epoch 42\n",
      "-------------------------------\n",
      "loss: 0.059185  [   64/25020]\n",
      "loss: 0.100046  [ 6464/25020]\n",
      "loss: 0.157865  [12864/25020]\n",
      "loss: 0.096414  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.8%\n",
      "Epoch 43\n",
      "-------------------------------\n",
      "loss: 0.058842  [   64/25020]\n",
      "loss: 0.088736  [ 6464/25020]\n",
      "loss: 0.113932  [12864/25020]\n",
      "loss: 0.073705  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.0%\n",
      "Epoch 44\n",
      "-------------------------------\n",
      "loss: 0.060747  [   64/25020]\n",
      "loss: 0.081952  [ 6464/25020]\n",
      "loss: 0.133010  [12864/25020]\n",
      "loss: 0.093318  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.0%\n",
      "Epoch 45\n",
      "-------------------------------\n",
      "loss: 0.057264  [   64/25020]\n",
      "loss: 0.088084  [ 6464/25020]\n",
      "loss: 0.139374  [12864/25020]\n",
      "loss: 0.093048  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.8%\n",
      "Epoch 46\n",
      "-------------------------------\n",
      "loss: 0.064944  [   64/25020]\n",
      "loss: 0.092358  [ 6464/25020]\n",
      "loss: 0.123140  [12864/25020]\n",
      "loss: 0.095485  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.1%\n",
      "Epoch 47\n",
      "-------------------------------\n",
      "loss: 0.060737  [   64/25020]\n",
      "loss: 0.088396  [ 6464/25020]\n",
      "loss: 0.118139  [12864/25020]\n",
      "loss: 0.081442  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 95.8%\n",
      "Epoch 48\n",
      "-------------------------------\n",
      "loss: 0.053220  [   64/25020]\n",
      "loss: 0.096035  [ 6464/25020]\n",
      "loss: 0.138090  [12864/25020]\n",
      "loss: 0.081876  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.0%\n",
      "Epoch 49\n",
      "-------------------------------\n",
      "loss: 0.059159  [   64/25020]\n",
      "loss: 0.086248  [ 6464/25020]\n",
      "loss: 0.144802  [12864/25020]\n",
      "loss: 0.070361  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.1%\n",
      "Epoch 50\n",
      "-------------------------------\n",
      "loss: 0.055205  [   64/25020]\n",
      "loss: 0.095269  [ 6464/25020]\n",
      "loss: 0.135786  [12864/25020]\n",
      "loss: 0.074466  [19264/25020]\n",
      "Train Error: \n",
      " Accuracy: 96.3%\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "dis_train = mem_data[dist_keep[0]]\n",
    "x = X_data[mem_label]\n",
    "y = conf_out_mean\n",
    "train_data = CustomDataset(x, y, model_transform)\n",
    "train_dataloader = DataLoader(train_data, batch_size=batch_size)\n",
    "TargetModel = globals()['create_{}_model'.format(model)](10, data_name)\n",
    "TargetModel.to(device)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(TargetModel.parameters(), lr=LEARNING_RATE)\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_onehot(train_dataloader, TargetModel, loss_fn, optimizer, device)\n",
    "print(\"Done!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c934ad3-68c9-4d80-9d3e-81b883adc11c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f63ea41-2c5d-459e-b0bb-76de6923cc70",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b9ff7317-a288-42a9-beb2-b90b95087e53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "test_dataset = datasets.cifar.CIFAR10(root='../datasets/cifar10', train=False, transform=None, download=True)\n",
    "x_test_data = test_dataset.data\n",
    "y_test_data = np.array(test_dataset.targets)\n",
    "test_data = CustomDataset(x_test_data, y_test_data, model_transform)\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3367bb-46fd-421e-a9a2-663332585761",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "77c2ee1f-142a-4111-bab2-83d9cda674bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: \n",
      " Accuracy: 70.7%, Avg loss: 1.352945 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate(test_dataloader, TargetModel, loss_fn, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3178802c-af4c-45eb-8325-b52fe6053573",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d90d1b0b-b4d6-4c9b-9944-84de61331f73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (fc1): Linear(in_features=1024, out_features=500, bias=True)\n",
       "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
       "  (BatchNorm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (BatchNorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (BatchNorm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       ")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "CompareModel = globals()['create_{}_model'.format(model)](10, data_name)\n",
    "weight_path = os.path.join(weight_dir, \"{}_{}_epoch{}_model{}.pth\".format(data_name, model, epochs, 0))\n",
    "# print(Reference_Model)\n",
    "CompareModel.load_state_dict(torch.load(weight_path))\n",
    "CompareModel.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "26ef4406-a132-47bc-8a8a-198e16b4c276",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: \n",
      " Accuracy: 71.2%, Avg loss: 2.209261 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate(test_dataloader, CompareModel, loss_fn, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13eeff50-d7c2-4cd8-b123-3671c2c894ab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d07e76-0fd3-4198-a06a-4e90e237afea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1348df97-7ce7-4df9-a70b-f4391ecc75cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_data_all = np.load('CIFAR10_loss.npy')\n",
    "score_all = np.load('CIFAR10_score.npy')\n",
    "conf_data_all = np.load('CIFAR10_conf.npy')\n",
    "pri_risk_all = get_risk_score(loss_data_all, train_keep)\n",
    "pri_risk_rank = np.argsort(pri_risk_all)\n",
    "pri_risk_rank = np.flip(pri_risk_rank)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "eefa70b3-5c4a-411b-be8f-e8110d9ee606",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: \n",
      " Accuracy: 73.9%, Avg loss: 1.292915 \n",
      "\n",
      "0.504\n"
     ]
    }
   ],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "pred_result = base_attack(all_dataloader, TargetModel, loss_fn, device)\n",
    "\n",
    "pred_clip = pred_result[pri_risk_rank[:500]]\n",
    "mem_clip = train_keep[0][pri_risk_rank[:500]]\n",
    "accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d088f212-08ef-4811-ae42-6f08c1baef29",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "d621c9f4-f319-4163-94eb-742672163287",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: \n",
      " Accuracy: 84.7%, Avg loss: 1.124064 \n",
      "\n",
      "0.952\n"
     ]
    }
   ],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "pred_result = base_attack(all_dataloader, CompareModel, loss_fn, device)\n",
    "\n",
    "pred_clip = pred_result[pri_risk_rank[:500]]\n",
    "mem_clip = train_keep[0][pri_risk_rank[:500]]\n",
    "accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e832d37-c2c8-44dc-a936-a943e6ccb09f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a70ea4-fa37-4e5b-862d-b6b5fd4c0a81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8ab569-2cd8-439d-a3d5-c1af27402d76",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5abd032a-d580-4bdd-8955-694a1da90828",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "3a5d77e1-87d4-4374-8c35-e13b2402ad53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC value is: 0.7769893964732137\n",
      "Accuracy is: 0.69174\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.69174"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_result = LIRA_attack(train_keep, score_all, score_all[0], train_keep[0])\n",
    "evaluate_ROC(pred_result, train_keep[0], threshold=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "41b8a152-d958-4da8-8607-c9309b4ff453",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.954\n"
     ]
    }
   ],
   "source": [
    "pred_clip = pred_result[pri_risk_rank[:500]]\n",
    "mem_clip = train_keep[0][pri_risk_rank[:500]]\n",
    "pred_clip = pred_clip > 0\n",
    "accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "383fdf4d-9ba8-43e2-b291-19ddd9d0469e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59c57604-07df-4131-98a2-1d6890d11729",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eec9979-5335-4d83-91cf-d148bb3a8a6d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bb0bff0-4c4e-4500-9b3b-86d8694d4b3e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b070bb-9530-4904-9c56-113874ceac09",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b58a2563-1783-482f-89ae-dfb1ff81ef9a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c5fff7-5099-4424-a912-e4f3747bf615",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2d627e-a382-4fad-930b-5dd87d6186f4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a7ee668-31bc-49c4-959d-34312b85f9ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "947d7f86-4f93-4c20-a8b6-c7110b57e064",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633bed5d-5544-4041-a1b3-e410297a7701",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c60dba-5775-43f8-ad72-3dd96a5daf63",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f433faa5-2210-4af3-8401-ae355ffa7557",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "651c4841-17b8-415f-b04f-c1431513a4c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f5b8813-157c-4b45-a526-46da14035c57",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b55bc42-ec4f-4d0d-9cb1-ef76e4fb1e65",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a99fdfd-a0a4-4058-9b80-3e203cc6caad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "895cfac8-4348-4444-91ce-c0da64c6b78d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e4a851-26aa-4d74-b491-4fdc8d836c6d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e0ad90-2114-4dd3-b320-7297c5abfebf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac77c100-03ce-4495-9195-10a75167f080",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3655005-f5d7-476d-9c18-dc42e6c2741b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1754e859-f23b-4a3c-8666-b2ce06c2b8f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = CustomDataset(X_data, Y_data, model_transform)\n",
    "train_dataloader = DataLoader(train_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0e839f4d-93e7-48f0-af07-349ebef049a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = BATCH_SIZE\n",
    "model = MODEL\n",
    "epochs = EPOCHS\n",
    "data_name = DATA_NAME "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bf4ae41-482b-464b-b2df-f79924f00fe2",
   "metadata": {},
   "source": [
    "## 脆弱点的两种提取方式"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a20f6954-5154-483b-a64b-8ab536b4989b",
   "metadata": {},
   "source": [
    "### 风险指标"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "90586eba-a7f4-4a19-a8a8-54d5d8633693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载所有参考模型上的损失、置信度、得分输出\n",
    "# conf_data_all, label_data, score_all = load_score_data_all(X_data, Y_data, weight_dir, num_shadowsets, data_name, model, epochs, model_transform, batch_size, device)\n",
    "# loss_fn = nn.CrossEntropyLoss(reduction='none')\n",
    "# loss_data_all, label_data = load_loss_data_all(X_data, Y_data, loss_fn, weight_dir, num_shadowsets, data_name, model, epochs, model_transform, batch_size, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5fbe2bf6-1750-4c23-91a8-80caf9b0825e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.save('CIFAR10_loss.npy', loss_data_all)\n",
    "# np.save('CIFAR10_score.npy', score_all)\n",
    "# np.save('CIFAR10_conf.npy', conf_data_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5a6e9c56-60bd-4bfc-a912-5e3b5eb52a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_data_all = np.load('CIFAR10_loss.npy')\n",
    "score_all = np.load('CIFAR10_score.npy')\n",
    "conf_data_all = np.load('CIFAR10_conf.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8b45bd6b-c391-4ced-99bd-57b0b8c3c30b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为每个数据点计算风险指标\n",
    "# 计算出一个点的脆弱程度评分\n",
    "pri_risk_all = get_risk_score(loss_data_all, train_keep)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d8258b9c-d6f3-4e64-ab4f-4b068bb4b209",
   "metadata": {},
   "outputs": [],
   "source": [
    "pri_risk_rank = np.argsort(pri_risk_all)\n",
    "pri_risk_rank = np.flip(pri_risk_rank)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16cc7233-a9a0-4e74-990d-70e0e6ebfb74",
   "metadata": {},
   "source": [
    "### 离群点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e7e42cbf-89f1-4d1f-bbd0-dc87252c22cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备好logits的输出\n",
    "# 计算余弦相似度 5w*5w的大型矩阵\n",
    "# 邻居距离alpha，邻居数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8d4c50e3-0ca3-463c-ad2a-4e4fd987b12e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# logits_data_all, label_data = load_logits_data_all(X_data, Y_data, weight_dir, num_shadowsets, data_name, model, epochs, model_transform, batch_size, device)\n",
    "# np.save('CIFAR10_logits.npy', logits_data_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a9314bc7-f529-46f5-834b-d1c387effef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_data_all = np.load('CIFAR10_logits.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "50d2365f-9e9d-4ffc-a9f7-16c1d3ed5646",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 按照k个模型进行拼接\n",
    "# k = 10\n",
    "# for i in range(k):\n",
    "#     if i == 0:\n",
    "#         combine_features = logits_data_all[i]\n",
    "#     else:\n",
    "#         combine_features = np.concatenate((combine_features, logits_data_all[i]),axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1308ab27-3798-4a3c-8dd2-b5ce0234042e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# # 数据量太大，不能保存所有的余弦相似度，只能需要时计算\n",
    "# alpha_list = [0.05, 0.1, 0.12, 0.15, 0.2, 0.3]\n",
    "# n_num_list = []\n",
    "# # for i in range(combine_features.shape[0]):\n",
    "# for i in range(10000):\n",
    "#     n_count = [0 for _ in alpha_list]\n",
    "#     if i%50 == 0:\n",
    "#         print(f\"compute to: {i}\")\n",
    "#     for j in range(combine_features.shape[0]):\n",
    "#         # 余弦距离的计算\n",
    "#         vec1 = combine_features[i]\n",
    "#         vec2 = combine_features[j]        \n",
    "#         cos_sim = vec1.dot(vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))\n",
    "#         cos_dis = 0.5 - 0.5 * cos_sim\n",
    "#         for m in range(len(alpha_list)):\n",
    "#             if (cos_dis < alpha_list[m]):\n",
    "#                 n_count[m] += 1\n",
    "#     n_num_list.append(n_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ae886e98-bf25-4c0d-a2fa-f2c948574aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# neigh_data_all = np.array(n_num_list)\n",
    "# np.save('CIFAR10_neigh.npy', neigh_data_all)\n",
    "neigh_data_all = np.load('CIFAR10_neigh.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "79592536-1181-4647-9e40-fc959daa668e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 6)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neigh_data_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9d815b0b-568b-46d5-bdeb-fba567a8e213",
   "metadata": {},
   "outputs": [],
   "source": [
    "neigh_num = neigh_data_all[:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7bc974ea-ffb5-432f-b9da-ab7abce10270",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neigh_num.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ddd4b8fa-1653-4bab-bf8f-4e2a116b17b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "risk_rank = np.argsort(neigh_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "789a839a-cb2b-4da7-a1a2-ecac971344ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([29631, 39852, 26325, ..., 47315, 41363, 17410])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "risk_rank"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cce6c5e2-b030-4874-b08e-bf8466815484",
   "metadata": {},
   "source": [
    "## 针对脆弱点展开攻击"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14cb98e0-737b-4ed0-93bd-a52681974606",
   "metadata": {},
   "source": [
    "### 基线攻击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8c66acdc-645a-4cf8-b0f6-babc51bce18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预测正确的判断为成员，预测不正确的判断为非成员"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ecc0a131-e71a-4cd1-b5a3-a6c750abc382",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (fc1): Linear(in_features=1024, out_features=500, bias=True)\n",
       "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
       "  (BatchNorm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (BatchNorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (BatchNorm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       ")"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 创建对应的目标模型\n",
    "if model in ['NN', 'NN_4layer']:\n",
    "    Target_Model = globals()['create_{}_model'.format(model)](X_data.shape[1], Y_data.max()+1)\n",
    "elif model == 'CNN':\n",
    "    Target_Model = globals()['create_{}_model'.format(model)](Y_data.max()+1, data_name)\n",
    "# 加载参数\n",
    "weight_path = os.path.join(weight_dir, \"{}_{}_epoch{}_model{}.pth\".format(data_name, model, epochs, tar_model))\n",
    "# print(Reference_Model)\n",
    "Target_Model.load_state_dict(torch.load(weight_path))\n",
    "Target_Model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "670ebafe-5710-4b0c-b05d-008a6098dbc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: \n",
      " Accuracy: 84.7%, Avg loss: 1.124064 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "pred_result = base_attack(train_dataloader, Target_Model, loss_fn, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "d0dfd319-c536-4a57-a2b9-246f9a577f60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8\n",
      "0.916\n"
     ]
    }
   ],
   "source": [
    "pred_clip = pred_result[risk_rank[:5000]]\n",
    "mem_clip = train_keep[0][risk_rank[:5000]]\n",
    "accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "print(accuracy)\n",
    "\n",
    "pred_clip = pred_result[pri_risk_rank[:5000]]\n",
    "mem_clip = train_keep[0][pri_risk_rank[:5000]]\n",
    "accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c08b9a40-068e-4411-8a51-13a05337a7cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "5cf3f668-0b0b-41b1-bc5f-73f89bed9169",
   "metadata": {},
   "source": [
    "### 阈值攻击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "61729130-6c97-4a19-b7c5-767f6ba1ae70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 基于损失的阈值去做攻击，阈值如何确定？两个均值的均值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "0dd9f430-c918-47cd-9e8d-4ee1b35bf143",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_threshold = get_loss_threshold(loss_data_all, train_keep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "18f8fd0e-f86a-47fc-8105-7a3a53b10b28",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_loss = loss_data_all[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "09270911-4d24-4c91-9c59-224213d1c319",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Membership prediction: a loss below the threshold is classified as a member\n",
    "pred_result = target_loss < loss_threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "ef57e35b-e721-4cea-ad31-6b7cc4ef5362",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7938\n",
      "0.8614\n"
     ]
    }
   ],
   "source": [
    "TOP_K = 5000  # evaluate on the TOP_K most at-risk samples\n",
    "# Identical evaluation for both vulnerability rankings -- one loop instead of\n",
    "# two copy-pasted blocks, so the slicing logic only lives in one place.\n",
    "for rank in (risk_rank, pri_risk_rank):\n",
    "    pred_clip = pred_result[rank[:TOP_K]]\n",
    "    mem_clip = train_keep[0][rank[:TOP_K]]\n",
    "    accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "    print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "736cdbcb-9dd3-45ab-bcc4-37c7f490a588",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8259bc0c-fbc1-4025-bf4a-b194236a39ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "786da5f1-1ae9-4f9b-ae7f-e18702ea84a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "77876b58-669a-483a-b80c-8b041486eda0",
   "metadata": {},
   "source": [
    "### 似然比攻击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f1304cdd-cd2c-437f-af20-bdd3eb68caab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 先对所有目标数据执行攻击，然后根据脆弱点筛选获取对应的攻击成功率或者ROC\n",
    "# 输出两个，memlabel和pred_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c9ed8ba5-393b-4bdd-8a65-bbd6bbcaf4b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC value is: 0.7769893964732137\n",
      "Accuracy is: 0.69174\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.69174"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LiRA: per-sample likelihood-ratio scores, then ROC/accuracy at decision threshold 0\n",
    "pred_result = LIRA_attack(train_keep, score_all, score_all[0], train_keep[0])\n",
    "evaluate_ROC(pred_result, train_keep[0], threshold=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "5fdb2095-a077-40a4-b954-c4947b015080",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8458\n",
      "0.9234\n"
     ]
    }
   ],
   "source": [
    "TOP_K = 5000  # evaluate on the TOP_K most at-risk samples\n",
    "# Identical evaluation for both vulnerability rankings -- one loop instead of\n",
    "# two copy-pasted blocks; a positive LiRA score means predicted member.\n",
    "for rank in (risk_rank, pri_risk_rank):\n",
    "    pred_clip = pred_result[rank[:TOP_K]] > 0\n",
    "    mem_clip = train_keep[0][rank[:TOP_K]]\n",
    "    accuracy = metrics.accuracy_score(mem_clip, pred_clip)\n",
    "    print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35c96a5b-843d-4cb6-b964-f13dadac3193",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26be2f5a-29e3-45c7-b6f0-c5382828f63f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "457be14e-4b30-488a-aab0-0ae0b345e27e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8ea3f420-9a04-420f-83b6-06ea1658f6c7",
   "metadata": {},
   "source": [
    "### 影子模型攻击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7c5e0038-e2b6-4997-a64f-c9e80ed2c9c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在所有数据上执行一次攻击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "7d7fefb0-3e20-40e2-be99-4527c0f62306",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      " Error: \n",
      " Accuracy: 98.2%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 70.7%  \n",
      "\n",
      "(50000, 10) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      " Error: \n",
      " Accuracy: 98.4%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 70.9%  \n",
      "\n",
      "(50000, 10) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      " Error: \n",
      " Accuracy: 97.3%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 70.7%  \n",
      "\n",
      "(50000, 10) (50000,) (50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      " Error: \n",
      " Accuracy: 98.5%  \n",
      "\n",
      " Error: \n",
      " Accuracy: 70.9%  \n",
      "\n",
      "test data: (50000, 10) (50000,) (50000,)\n",
      "(150000, 10) (150000,)\n",
      "Attack_NN(\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=11, out_features=128, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=128, out_features=64, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=64, out_features=1, bias=True)\n",
      "  )\n",
      ")\n",
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 0.697296  [   64/150000]\n",
      "loss: 0.691852  [ 6464/150000]\n",
      "loss: 0.687517  [12864/150000]\n",
      "loss: 0.651151  [19264/150000]\n",
      "loss: 0.638183  [25664/150000]\n",
      "loss: 0.665337  [32064/150000]\n",
      "loss: 0.629165  [38464/150000]\n",
      "loss: 0.629897  [44864/150000]\n",
      "loss: 0.633173  [51264/150000]\n",
      "loss: 0.613899  [57664/150000]\n",
      "loss: 0.603700  [64064/150000]\n",
      "loss: 0.624719  [70464/150000]\n",
      "loss: 0.607570  [76864/150000]\n",
      "loss: 0.615966  [83264/150000]\n",
      "loss: 0.594782  [89664/150000]\n",
      "loss: 0.612963  [96064/150000]\n",
      "loss: 0.666927  [102464/150000]\n",
      "loss: 0.635343  [108864/150000]\n",
      "loss: 0.551613  [115264/150000]\n",
      "loss: 0.614503  [121664/150000]\n",
      "loss: 0.609472  [128064/150000]\n",
      "loss: 0.651940  [134464/150000]\n",
      "loss: 0.623228  [140864/150000]\n",
      "loss: 0.632289  [147264/150000]\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 0.597696  [   64/150000]\n",
      "loss: 0.666137  [ 6464/150000]\n",
      "loss: 0.574050  [12864/150000]\n",
      "loss: 0.663844  [19264/150000]\n",
      "loss: 0.615269  [25664/150000]\n",
      "loss: 0.637435  [32064/150000]\n",
      "loss: 0.607537  [38464/150000]\n",
      "loss: 0.640098  [44864/150000]\n",
      "loss: 0.555198  [51264/150000]\n",
      "loss: 0.508200  [57664/150000]\n",
      "loss: 0.649726  [64064/150000]\n",
      "loss: 0.579738  [70464/150000]\n",
      "loss: 0.639545  [76864/150000]\n",
      "loss: 0.595912  [83264/150000]\n",
      "loss: 0.603856  [89664/150000]\n",
      "loss: 0.624075  [96064/150000]\n",
      "loss: 0.644577  [102464/150000]\n",
      "loss: 0.573552  [108864/150000]\n",
      "loss: 0.593017  [115264/150000]\n",
      "loss: 0.678881  [121664/150000]\n",
      "loss: 0.645067  [128064/150000]\n",
      "loss: 0.678110  [134464/150000]\n",
      "loss: 0.592467  [140864/150000]\n",
      "loss: 0.644923  [147264/150000]\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 0.549876  [   64/150000]\n",
      "loss: 0.540238  [ 6464/150000]\n",
      "loss: 0.602370  [12864/150000]\n",
      "loss: 0.577029  [19264/150000]\n",
      "loss: 0.573094  [25664/150000]\n",
      "loss: 0.556609  [32064/150000]\n",
      "loss: 0.588493  [38464/150000]\n",
      "loss: 0.620299  [44864/150000]\n",
      "loss: 0.532603  [51264/150000]\n",
      "loss: 0.550144  [57664/150000]\n",
      "loss: 0.655471  [64064/150000]\n",
      "loss: 0.557127  [70464/150000]\n",
      "loss: 0.562649  [76864/150000]\n",
      "loss: 0.611032  [83264/150000]\n",
      "loss: 0.609907  [89664/150000]\n",
      "loss: 0.641321  [96064/150000]\n",
      "loss: 0.517016  [102464/150000]\n",
      "loss: 0.548020  [108864/150000]\n",
      "loss: 0.606035  [115264/150000]\n",
      "loss: 0.712533  [121664/150000]\n",
      "loss: 0.600617  [128064/150000]\n",
      "loss: 0.581069  [134464/150000]\n",
      "loss: 0.684818  [140864/150000]\n",
      "loss: 0.590795  [147264/150000]\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 0.614286  [   64/150000]\n",
      "loss: 0.603604  [ 6464/150000]\n",
      "loss: 0.656957  [12864/150000]\n",
      "loss: 0.724023  [19264/150000]\n",
      "loss: 0.610916  [25664/150000]\n",
      "loss: 0.620281  [32064/150000]\n",
      "loss: 0.583793  [38464/150000]\n",
      "loss: 0.590866  [44864/150000]\n",
      "loss: 0.590776  [51264/150000]\n",
      "loss: 0.535479  [57664/150000]\n",
      "loss: 0.576726  [64064/150000]\n",
      "loss: 0.629341  [70464/150000]\n",
      "loss: 0.564287  [76864/150000]\n",
      "loss: 0.611278  [83264/150000]\n",
      "loss: 0.531345  [89664/150000]\n",
      "loss: 0.570144  [96064/150000]\n",
      "loss: 0.544331  [102464/150000]\n",
      "loss: 0.641698  [108864/150000]\n",
      "loss: 0.555752  [115264/150000]\n",
      "loss: 0.562733  [121664/150000]\n",
      "loss: 0.679568  [128064/150000]\n",
      "loss: 0.604731  [134464/150000]\n",
      "loss: 0.575098  [140864/150000]\n",
      "loss: 0.597137  [147264/150000]\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 0.557888  [   64/150000]\n",
      "loss: 0.536265  [ 6464/150000]\n",
      "loss: 0.493845  [12864/150000]\n",
      "loss: 0.616729  [19264/150000]\n",
      "loss: 0.629660  [25664/150000]\n",
      "loss: 0.672249  [32064/150000]\n",
      "loss: 0.625117  [38464/150000]\n",
      "loss: 0.594650  [44864/150000]\n",
      "loss: 0.611802  [51264/150000]\n",
      "loss: 0.610690  [57664/150000]\n",
      "loss: 0.582970  [64064/150000]\n",
      "loss: 0.507220  [70464/150000]\n",
      "loss: 0.587091  [76864/150000]\n",
      "loss: 0.644539  [83264/150000]\n",
      "loss: 0.567560  [89664/150000]\n",
      "loss: 0.635086  [96064/150000]\n",
      "loss: 0.540519  [102464/150000]\n",
      "loss: 0.587662  [108864/150000]\n",
      "loss: 0.594138  [115264/150000]\n",
      "loss: 0.623967  [121664/150000]\n",
      "loss: 0.630961  [128064/150000]\n",
      "loss: 0.540283  [134464/150000]\n",
      "loss: 0.660469  [140864/150000]\n",
      "loss: 0.615367  [147264/150000]\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 0.644627  [   64/150000]\n",
      "loss: 0.641388  [ 6464/150000]\n",
      "loss: 0.583667  [12864/150000]\n",
      "loss: 0.595985  [19264/150000]\n",
      "loss: 0.639247  [25664/150000]\n",
      "loss: 0.553457  [32064/150000]\n",
      "loss: 0.644644  [38464/150000]\n",
      "loss: 0.659022  [44864/150000]\n",
      "loss: 0.603649  [51264/150000]\n",
      "loss: 0.556295  [57664/150000]\n",
      "loss: 0.633961  [64064/150000]\n",
      "loss: 0.603780  [70464/150000]\n",
      "loss: 0.635701  [76864/150000]\n",
      "loss: 0.617778  [83264/150000]\n",
      "loss: 0.576959  [89664/150000]\n",
      "loss: 0.612733  [96064/150000]\n",
      "loss: 0.606709  [102464/150000]\n",
      "loss: 0.534224  [108864/150000]\n",
      "loss: 0.552378  [115264/150000]\n",
      "loss: 0.582934  [121664/150000]\n",
      "loss: 0.590858  [128064/150000]\n",
      "loss: 0.572298  [134464/150000]\n",
      "loss: 0.577786  [140864/150000]\n",
      "loss: 0.601727  [147264/150000]\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 0.585471  [   64/150000]\n",
      "loss: 0.620201  [ 6464/150000]\n",
      "loss: 0.629633  [12864/150000]\n",
      "loss: 0.633278  [19264/150000]\n",
      "loss: 0.581919  [25664/150000]\n",
      "loss: 0.639833  [32064/150000]\n",
      "loss: 0.565880  [38464/150000]\n",
      "loss: 0.618099  [44864/150000]\n",
      "loss: 0.540124  [51264/150000]\n",
      "loss: 0.533662  [57664/150000]\n",
      "loss: 0.537237  [64064/150000]\n",
      "loss: 0.654931  [70464/150000]\n",
      "loss: 0.579782  [76864/150000]\n",
      "loss: 0.584689  [83264/150000]\n",
      "loss: 0.596409  [89664/150000]\n",
      "loss: 0.616381  [96064/150000]\n",
      "loss: 0.590769  [102464/150000]\n",
      "loss: 0.616326  [108864/150000]\n",
      "loss: 0.685181  [115264/150000]\n",
      "loss: 0.593606  [121664/150000]\n",
      "loss: 0.644492  [128064/150000]\n",
      "loss: 0.657713  [134464/150000]\n",
      "loss: 0.686037  [140864/150000]\n",
      "loss: 0.612704  [147264/150000]\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 0.565326  [   64/150000]\n",
      "loss: 0.596956  [ 6464/150000]\n",
      "loss: 0.633713  [12864/150000]\n",
      "loss: 0.675944  [19264/150000]\n",
      "loss: 0.558952  [25664/150000]\n",
      "loss: 0.647076  [32064/150000]\n",
      "loss: 0.617525  [38464/150000]\n",
      "loss: 0.561403  [44864/150000]\n",
      "loss: 0.656842  [51264/150000]\n",
      "loss: 0.644154  [57664/150000]\n",
      "loss: 0.543379  [64064/150000]\n",
      "loss: 0.672034  [70464/150000]\n",
      "loss: 0.677683  [76864/150000]\n",
      "loss: 0.586226  [83264/150000]\n",
      "loss: 0.612088  [89664/150000]\n",
      "loss: 0.554074  [96064/150000]\n",
      "loss: 0.608314  [102464/150000]\n",
      "loss: 0.596135  [108864/150000]\n",
      "loss: 0.572143  [115264/150000]\n",
      "loss: 0.575213  [121664/150000]\n",
      "loss: 0.617696  [128064/150000]\n",
      "loss: 0.516614  [134464/150000]\n",
      "loss: 0.608312  [140864/150000]\n",
      "loss: 0.655292  [147264/150000]\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 0.632591  [   64/150000]\n",
      "loss: 0.573968  [ 6464/150000]\n",
      "loss: 0.606144  [12864/150000]\n",
      "loss: 0.570658  [19264/150000]\n",
      "loss: 0.582340  [25664/150000]\n",
      "loss: 0.686753  [32064/150000]\n",
      "loss: 0.566058  [38464/150000]\n",
      "loss: 0.554912  [44864/150000]\n",
      "loss: 0.565513  [51264/150000]\n",
      "loss: 0.685594  [57664/150000]\n",
      "loss: 0.577988  [64064/150000]\n",
      "loss: 0.601425  [70464/150000]\n",
      "loss: 0.529391  [76864/150000]\n",
      "loss: 0.522878  [83264/150000]\n",
      "loss: 0.630615  [89664/150000]\n",
      "loss: 0.667440  [96064/150000]\n",
      "loss: 0.540676  [102464/150000]\n",
      "loss: 0.535263  [108864/150000]\n",
      "loss: 0.621468  [115264/150000]\n",
      "loss: 0.637567  [121664/150000]\n",
      "loss: 0.629170  [128064/150000]\n",
      "loss: 0.606112  [134464/150000]\n",
      "loss: 0.571968  [140864/150000]\n",
      "loss: 0.615397  [147264/150000]\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 0.605752  [   64/150000]\n",
      "loss: 0.573386  [ 6464/150000]\n",
      "loss: 0.575619  [12864/150000]\n",
      "loss: 0.593930  [19264/150000]\n",
      "loss: 0.606027  [25664/150000]\n",
      "loss: 0.603035  [32064/150000]\n",
      "loss: 0.534000  [38464/150000]\n",
      "loss: 0.565585  [44864/150000]\n",
      "loss: 0.589906  [51264/150000]\n",
      "loss: 0.592201  [57664/150000]\n",
      "loss: 0.628924  [64064/150000]\n",
      "loss: 0.616412  [70464/150000]\n",
      "loss: 0.587669  [76864/150000]\n",
      "loss: 0.608327  [83264/150000]\n",
      "loss: 0.583130  [89664/150000]\n",
      "loss: 0.543986  [96064/150000]\n",
      "loss: 0.607370  [102464/150000]\n",
      "loss: 0.546562  [108864/150000]\n",
      "loss: 0.591624  [115264/150000]\n",
      "loss: 0.609809  [121664/150000]\n",
      "loss: 0.578402  [128064/150000]\n",
      "loss: 0.595384  [134464/150000]\n",
      "loss: 0.625442  [140864/150000]\n",
      "loss: 0.665480  [147264/150000]\n",
      "Epoch 11\n",
      "-------------------------------\n",
      "loss: 0.579572  [   64/150000]\n",
      "loss: 0.677842  [ 6464/150000]\n",
      "loss: 0.593931  [12864/150000]\n",
      "loss: 0.592321  [19264/150000]\n",
      "loss: 0.575612  [25664/150000]\n",
      "loss: 0.571065  [32064/150000]\n",
      "loss: 0.575749  [38464/150000]\n",
      "loss: 0.616480  [44864/150000]\n",
      "loss: 0.472566  [51264/150000]\n",
      "loss: 0.550193  [57664/150000]\n",
      "loss: 0.557135  [64064/150000]\n",
      "loss: 0.567849  [70464/150000]\n",
      "loss: 0.661146  [76864/150000]\n",
      "loss: 0.578843  [83264/150000]\n",
      "loss: 0.588202  [89664/150000]\n",
      "loss: 0.639270  [96064/150000]\n",
      "loss: 0.718008  [102464/150000]\n",
      "loss: 0.622295  [108864/150000]\n",
      "loss: 0.628525  [115264/150000]\n",
      "loss: 0.623000  [121664/150000]\n",
      "loss: 0.606138  [128064/150000]\n",
      "loss: 0.554370  [134464/150000]\n",
      "loss: 0.689069  [140864/150000]\n",
      "loss: 0.555399  [147264/150000]\n",
      "Epoch 12\n",
      "-------------------------------\n",
      "loss: 0.552356  [   64/150000]\n",
      "loss: 0.686734  [ 6464/150000]\n",
      "loss: 0.609447  [12864/150000]\n",
      "loss: 0.604392  [19264/150000]\n",
      "loss: 0.602954  [25664/150000]\n",
      "loss: 0.546178  [32064/150000]\n",
      "loss: 0.628506  [38464/150000]\n",
      "loss: 0.673439  [44864/150000]\n",
      "loss: 0.625517  [51264/150000]\n",
      "loss: 0.598829  [57664/150000]\n",
      "loss: 0.568111  [64064/150000]\n",
      "loss: 0.698848  [70464/150000]\n",
      "loss: 0.625230  [76864/150000]\n",
      "loss: 0.600533  [83264/150000]\n",
      "loss: 0.591708  [89664/150000]\n",
      "loss: 0.605737  [96064/150000]\n",
      "loss: 0.617584  [102464/150000]\n",
      "loss: 0.605771  [108864/150000]\n",
      "loss: 0.697467  [115264/150000]\n",
      "loss: 0.591232  [121664/150000]\n",
      "loss: 0.552443  [128064/150000]\n",
      "loss: 0.592739  [134464/150000]\n",
      "loss: 0.589756  [140864/150000]\n",
      "loss: 0.534506  [147264/150000]\n",
      "Epoch 13\n",
      "-------------------------------\n",
      "loss: 0.640633  [   64/150000]\n",
      "loss: 0.697667  [ 6464/150000]\n",
      "loss: 0.592239  [12864/150000]\n",
      "loss: 0.595909  [19264/150000]\n",
      "loss: 0.538048  [25664/150000]\n",
      "loss: 0.567052  [32064/150000]\n",
      "loss: 0.565347  [38464/150000]\n",
      "loss: 0.628676  [44864/150000]\n",
      "loss: 0.532067  [51264/150000]\n",
      "loss: 0.569317  [57664/150000]\n",
      "loss: 0.662629  [64064/150000]\n",
      "loss: 0.690312  [70464/150000]\n",
      "loss: 0.549885  [76864/150000]\n",
      "loss: 0.621373  [83264/150000]\n",
      "loss: 0.593378  [89664/150000]\n",
      "loss: 0.655877  [96064/150000]\n",
      "loss: 0.596402  [102464/150000]\n",
      "loss: 0.670918  [108864/150000]\n",
      "loss: 0.597230  [115264/150000]\n",
      "loss: 0.592003  [121664/150000]\n",
      "loss: 0.618691  [128064/150000]\n",
      "loss: 0.580073  [134464/150000]\n",
      "loss: 0.650928  [140864/150000]\n",
      "loss: 0.595747  [147264/150000]\n",
      "Epoch 14\n",
      "-------------------------------\n",
      "loss: 0.631570  [   64/150000]\n",
      "loss: 0.581386  [ 6464/150000]\n",
      "loss: 0.530600  [12864/150000]\n",
      "loss: 0.649409  [19264/150000]\n",
      "loss: 0.652399  [25664/150000]\n",
      "loss: 0.567224  [32064/150000]\n",
      "loss: 0.562777  [38464/150000]\n",
      "loss: 0.475368  [44864/150000]\n",
      "loss: 0.525190  [51264/150000]\n",
      "loss: 0.599928  [57664/150000]\n",
      "loss: 0.548746  [64064/150000]\n",
      "loss: 0.474745  [70464/150000]\n",
      "loss: 0.626452  [76864/150000]\n",
      "loss: 0.503386  [83264/150000]\n",
      "loss: 0.494949  [89664/150000]\n",
      "loss: 0.615882  [96064/150000]\n",
      "loss: 0.583681  [102464/150000]\n",
      "loss: 0.605108  [108864/150000]\n",
      "loss: 0.607096  [115264/150000]\n",
      "loss: 0.582837  [121664/150000]\n",
      "loss: 0.605536  [128064/150000]\n",
      "loss: 0.604692  [134464/150000]\n",
      "loss: 0.584882  [140864/150000]\n",
      "loss: 0.534502  [147264/150000]\n",
      "Epoch 15\n",
      "-------------------------------\n",
      "loss: 0.544695  [   64/150000]\n",
      "loss: 0.573055  [ 6464/150000]\n",
      "loss: 0.550669  [12864/150000]\n",
      "loss: 0.723938  [19264/150000]\n",
      "loss: 0.584465  [25664/150000]\n",
      "loss: 0.569426  [32064/150000]\n",
      "loss: 0.607056  [38464/150000]\n",
      "loss: 0.616925  [44864/150000]\n",
      "loss: 0.642344  [51264/150000]\n",
      "loss: 0.549334  [57664/150000]\n",
      "loss: 0.587071  [64064/150000]\n",
      "loss: 0.615041  [70464/150000]\n",
      "loss: 0.579657  [76864/150000]\n",
      "loss: 0.601171  [83264/150000]\n",
      "loss: 0.612851  [89664/150000]\n",
      "loss: 0.618960  [96064/150000]\n",
      "loss: 0.601881  [102464/150000]\n",
      "loss: 0.640874  [108864/150000]\n",
      "loss: 0.504944  [115264/150000]\n",
      "loss: 0.580753  [121664/150000]\n",
      "loss: 0.576069  [128064/150000]\n",
      "loss: 0.545511  [134464/150000]\n",
      "loss: 0.611074  [140864/150000]\n",
      "loss: 0.621706  [147264/150000]\n",
      "Epoch 16\n",
      "-------------------------------\n",
      "loss: 0.524734  [   64/150000]\n",
      "loss: 0.669807  [ 6464/150000]\n",
      "loss: 0.589392  [12864/150000]\n",
      "loss: 0.582379  [19264/150000]\n",
      "loss: 0.664578  [25664/150000]\n",
      "loss: 0.615774  [32064/150000]\n",
      "loss: 0.622845  [38464/150000]\n",
      "loss: 0.538043  [44864/150000]\n",
      "loss: 0.562914  [51264/150000]\n",
      "loss: 0.582615  [57664/150000]\n",
      "loss: 0.512658  [64064/150000]\n",
      "loss: 0.580139  [70464/150000]\n",
      "loss: 0.594151  [76864/150000]\n",
      "loss: 0.647114  [83264/150000]\n",
      "loss: 0.592874  [89664/150000]\n",
      "loss: 0.643142  [96064/150000]\n",
      "loss: 0.584972  [102464/150000]\n",
      "loss: 0.561617  [108864/150000]\n",
      "loss: 0.597275  [115264/150000]\n",
      "loss: 0.599875  [121664/150000]\n",
      "loss: 0.532628  [128064/150000]\n",
      "loss: 0.521613  [134464/150000]\n",
      "loss: 0.664630  [140864/150000]\n",
      "loss: 0.608385  [147264/150000]\n",
      "Epoch 17\n",
      "-------------------------------\n",
      "loss: 0.621746  [   64/150000]\n",
      "loss: 0.511989  [ 6464/150000]\n",
      "loss: 0.563653  [12864/150000]\n",
      "loss: 0.562508  [19264/150000]\n",
      "loss: 0.498100  [25664/150000]\n",
      "loss: 0.594418  [32064/150000]\n",
      "loss: 0.583726  [38464/150000]\n",
      "loss: 0.588750  [44864/150000]\n",
      "loss: 0.521480  [51264/150000]\n",
      "loss: 0.579815  [57664/150000]\n",
      "loss: 0.542970  [64064/150000]\n",
      "loss: 0.544266  [70464/150000]\n",
      "loss: 0.571421  [76864/150000]\n",
      "loss: 0.526584  [83264/150000]\n",
      "loss: 0.591572  [89664/150000]\n",
      "loss: 0.611933  [96064/150000]\n",
      "loss: 0.562910  [102464/150000]\n",
      "loss: 0.613617  [108864/150000]\n",
      "loss: 0.639856  [115264/150000]\n",
      "loss: 0.555695  [121664/150000]\n",
      "loss: 0.595588  [128064/150000]\n",
      "loss: 0.634104  [134464/150000]\n",
      "loss: 0.585836  [140864/150000]\n",
      "loss: 0.635098  [147264/150000]\n",
      "Epoch 18\n",
      "-------------------------------\n",
      "loss: 0.599146  [   64/150000]\n",
      "loss: 0.618983  [ 6464/150000]\n",
      "loss: 0.562647  [12864/150000]\n",
      "loss: 0.629981  [19264/150000]\n",
      "loss: 0.557569  [25664/150000]\n",
      "loss: 0.644682  [32064/150000]\n",
      "loss: 0.552329  [38464/150000]\n",
      "loss: 0.632133  [44864/150000]\n",
      "loss: 0.667953  [51264/150000]\n",
      "loss: 0.579287  [57664/150000]\n",
      "loss: 0.551793  [64064/150000]\n",
      "loss: 0.516235  [70464/150000]\n",
      "loss: 0.623689  [76864/150000]\n",
      "loss: 0.583543  [83264/150000]\n",
      "loss: 0.602581  [89664/150000]\n",
      "loss: 0.617562  [96064/150000]\n",
      "loss: 0.589883  [102464/150000]\n",
      "loss: 0.594558  [108864/150000]\n",
      "loss: 0.591930  [115264/150000]\n",
      "loss: 0.598964  [121664/150000]\n",
      "loss: 0.631221  [128064/150000]\n",
      "loss: 0.569307  [134464/150000]\n",
      "loss: 0.626202  [140864/150000]\n",
      "loss: 0.579201  [147264/150000]\n",
      "Epoch 19\n",
      "-------------------------------\n",
      "loss: 0.588934  [   64/150000]\n",
      "loss: 0.588312  [ 6464/150000]\n",
      "loss: 0.513917  [12864/150000]\n",
      "loss: 0.606465  [19264/150000]\n",
      "loss: 0.592342  [25664/150000]\n",
      "loss: 0.615711  [32064/150000]\n",
      "loss: 0.557372  [38464/150000]\n",
      "loss: 0.559006  [44864/150000]\n",
      "loss: 0.585443  [51264/150000]\n",
      "loss: 0.485954  [57664/150000]\n",
      "loss: 0.589985  [64064/150000]\n",
      "loss: 0.587740  [70464/150000]\n",
      "loss: 0.629844  [76864/150000]\n",
      "loss: 0.521275  [83264/150000]\n",
      "loss: 0.603965  [89664/150000]\n",
      "loss: 0.644880  [96064/150000]\n",
      "loss: 0.630848  [102464/150000]\n",
      "loss: 0.571996  [108864/150000]\n",
      "loss: 0.580071  [115264/150000]\n",
      "loss: 0.593456  [121664/150000]\n",
      "loss: 0.551542  [128064/150000]\n",
      "loss: 0.611395  [134464/150000]\n",
      "loss: 0.577660  [140864/150000]\n",
      "loss: 0.612326  [147264/150000]\n",
      "Epoch 20\n",
      "-------------------------------\n",
      "loss: 0.627417  [   64/150000]\n",
      "loss: 0.637460  [ 6464/150000]\n",
      "loss: 0.666766  [12864/150000]\n",
      "loss: 0.531427  [19264/150000]\n",
      "loss: 0.562119  [25664/150000]\n",
      "loss: 0.587002  [32064/150000]\n",
      "loss: 0.652636  [38464/150000]\n",
      "loss: 0.559747  [44864/150000]\n",
      "loss: 0.587982  [51264/150000]\n",
      "loss: 0.601106  [57664/150000]\n",
      "loss: 0.565837  [64064/150000]\n",
      "loss: 0.607227  [70464/150000]\n",
      "loss: 0.626141  [76864/150000]\n",
      "loss: 0.640508  [83264/150000]\n",
      "loss: 0.605786  [89664/150000]\n",
      "loss: 0.608159  [96064/150000]\n",
      "loss: 0.530527  [102464/150000]\n",
      "loss: 0.587791  [108864/150000]\n",
      "loss: 0.593820  [115264/150000]\n",
      "loss: 0.615074  [121664/150000]\n",
      "loss: 0.616388  [128064/150000]\n",
      "loss: 0.569421  [134464/150000]\n",
      "loss: 0.607489  [140864/150000]\n",
      "loss: 0.645023  [147264/150000]\n",
      "Epoch 21\n",
      "-------------------------------\n",
      "loss: 0.603664  [   64/150000]\n",
      "loss: 0.614986  [ 6464/150000]\n",
      "loss: 0.521039  [12864/150000]\n",
      "loss: 0.547322  [19264/150000]\n",
      "loss: 0.577473  [25664/150000]\n",
      "loss: 0.597063  [32064/150000]\n",
      "loss: 0.585668  [38464/150000]\n",
      "loss: 0.541477  [44864/150000]\n",
      "loss: 0.588172  [51264/150000]\n",
      "loss: 0.645711  [57664/150000]\n",
      "loss: 0.617174  [64064/150000]\n",
      "loss: 0.659786  [70464/150000]\n",
      "loss: 0.518750  [76864/150000]\n",
      "loss: 0.564853  [83264/150000]\n",
      "loss: 0.540514  [89664/150000]\n",
      "loss: 0.677545  [96064/150000]\n",
      "loss: 0.583586  [102464/150000]\n",
      "loss: 0.587650  [108864/150000]\n",
      "loss: 0.618852  [115264/150000]\n",
      "loss: 0.546123  [121664/150000]\n",
      "loss: 0.611922  [128064/150000]\n",
      "loss: 0.769938  [134464/150000]\n",
      "loss: 0.654696  [140864/150000]\n",
      "loss: 0.555877  [147264/150000]\n",
      "Epoch 22\n",
      "-------------------------------\n",
      "loss: 0.546781  [   64/150000]\n",
      "loss: 0.663820  [ 6464/150000]\n",
      "loss: 0.586681  [12864/150000]\n",
      "loss: 0.545602  [19264/150000]\n",
      "loss: 0.622281  [25664/150000]\n",
      "loss: 0.613386  [32064/150000]\n",
      "loss: 0.679810  [38464/150000]\n",
      "loss: 0.561913  [44864/150000]\n",
      "loss: 0.588024  [51264/150000]\n",
      "loss: 0.646534  [57664/150000]\n",
      "loss: 0.567236  [64064/150000]\n",
      "loss: 0.582543  [70464/150000]\n",
      "loss: 0.581364  [76864/150000]\n",
      "loss: 0.589139  [83264/150000]\n",
      "loss: 0.651509  [89664/150000]\n",
      "loss: 0.562445  [96064/150000]\n",
      "loss: 0.679964  [102464/150000]\n",
      "loss: 0.560593  [108864/150000]\n",
      "loss: 0.535673  [115264/150000]\n",
      "loss: 0.605118  [121664/150000]\n",
      "loss: 0.591122  [128064/150000]\n",
      "loss: 0.576950  [134464/150000]\n",
      "loss: 0.554719  [140864/150000]\n",
      "loss: 0.568256  [147264/150000]\n",
      "Epoch 23\n",
      "-------------------------------\n",
      "loss: 0.515077  [   64/150000]\n",
      "loss: 0.553921  [ 6464/150000]\n",
      "loss: 0.577074  [12864/150000]\n",
      "loss: 0.585402  [19264/150000]\n",
      "loss: 0.611172  [25664/150000]\n",
      "loss: 0.587801  [32064/150000]\n",
      "loss: 0.588383  [38464/150000]\n",
      "loss: 0.711319  [44864/150000]\n",
      "loss: 0.527017  [51264/150000]\n",
      "loss: 0.678208  [57664/150000]\n",
      "loss: 0.589850  [64064/150000]\n",
      "loss: 0.584371  [70464/150000]\n",
      "loss: 0.581527  [76864/150000]\n",
      "loss: 0.653242  [83264/150000]\n",
      "loss: 0.613740  [89664/150000]\n",
      "loss: 0.525126  [96064/150000]\n",
      "loss: 0.609433  [102464/150000]\n",
      "loss: 0.608985  [108864/150000]\n",
      "loss: 0.682752  [115264/150000]\n",
      "loss: 0.592517  [121664/150000]\n",
      "loss: 0.597739  [128064/150000]\n",
      "loss: 0.542855  [134464/150000]\n",
      "loss: 0.600933  [140864/150000]\n",
      "loss: 0.548048  [147264/150000]\n",
      "Epoch 24\n",
      "-------------------------------\n",
      "loss: 0.538409  [   64/150000]\n",
      "loss: 0.587732  [ 6464/150000]\n",
      "loss: 0.553931  [12864/150000]\n",
      "loss: 0.615992  [19264/150000]\n",
      "loss: 0.546581  [25664/150000]\n",
      "loss: 0.550231  [32064/150000]\n",
      "loss: 0.648659  [38464/150000]\n",
      "loss: 0.633638  [44864/150000]\n",
      "loss: 0.550732  [51264/150000]\n",
      "loss: 0.543928  [57664/150000]\n",
      "loss: 0.617104  [64064/150000]\n",
      "loss: 0.560582  [70464/150000]\n",
      "loss: 0.669584  [76864/150000]\n",
      "loss: 0.601878  [83264/150000]\n",
      "loss: 0.527612  [89664/150000]\n",
      "loss: 0.572111  [96064/150000]\n",
      "loss: 0.598589  [102464/150000]\n",
      "loss: 0.603640  [108864/150000]\n",
      "loss: 0.608930  [115264/150000]\n",
      "loss: 0.534105  [121664/150000]\n",
      "loss: 0.629878  [128064/150000]\n",
      "loss: 0.611613  [134464/150000]\n",
      "loss: 0.513288  [140864/150000]\n",
      "loss: 0.635841  [147264/150000]\n",
      "Epoch 25\n",
      "-------------------------------\n",
      "loss: 0.557434  [   64/150000]\n",
      "loss: 0.615803  [ 6464/150000]\n",
      "loss: 0.579832  [12864/150000]\n",
      "loss: 0.575765  [19264/150000]\n",
      "loss: 0.606995  [25664/150000]\n",
      "loss: 0.632700  [32064/150000]\n",
      "loss: 0.632381  [38464/150000]\n",
      "loss: 0.615837  [44864/150000]\n",
      "loss: 0.587369  [51264/150000]\n",
      "loss: 0.600212  [57664/150000]\n",
      "loss: 0.616872  [64064/150000]\n",
      "loss: 0.609934  [70464/150000]\n",
      "loss: 0.662792  [76864/150000]\n",
      "loss: 0.589709  [83264/150000]\n",
      "loss: 0.618774  [89664/150000]\n",
      "loss: 0.645855  [96064/150000]\n",
      "loss: 0.551524  [102464/150000]\n",
      "loss: 0.533778  [108864/150000]\n",
      "loss: 0.577648  [115264/150000]\n",
      "loss: 0.544056  [121664/150000]\n",
      "loss: 0.614962  [128064/150000]\n",
      "loss: 0.539947  [134464/150000]\n",
      "loss: 0.612202  [140864/150000]\n",
      "loss: 0.560787  [147264/150000]\n",
      "Epoch 26\n",
      "-------------------------------\n",
      "loss: 0.562039  [   64/150000]\n",
      "loss: 0.557390  [ 6464/150000]\n",
      "loss: 0.611007  [12864/150000]\n",
      "loss: 0.615194  [19264/150000]\n",
      "loss: 0.574516  [25664/150000]\n",
      "loss: 0.613461  [32064/150000]\n",
      "loss: 0.606872  [38464/150000]\n",
      "loss: 0.565459  [44864/150000]\n",
      "loss: 0.585175  [51264/150000]\n",
      "loss: 0.625837  [57664/150000]\n",
      "loss: 0.561817  [64064/150000]\n",
      "loss: 0.596336  [70464/150000]\n",
      "loss: 0.600057  [76864/150000]\n",
      "loss: 0.573806  [83264/150000]\n",
      "loss: 0.641340  [89664/150000]\n",
      "loss: 0.608286  [96064/150000]\n",
      "loss: 0.649721  [102464/150000]\n",
      "loss: 0.671843  [108864/150000]\n",
      "loss: 0.564814  [115264/150000]\n",
      "loss: 0.632006  [121664/150000]\n",
      "loss: 0.537901  [128064/150000]\n",
      "loss: 0.519129  [134464/150000]\n",
      "loss: 0.604893  [140864/150000]\n",
      "loss: 0.565878  [147264/150000]\n",
      "Epoch 27\n",
      "-------------------------------\n",
      "loss: 0.637899  [   64/150000]\n",
      "loss: 0.619308  [ 6464/150000]\n",
      "loss: 0.565980  [12864/150000]\n",
      "loss: 0.580520  [19264/150000]\n",
      "loss: 0.517601  [25664/150000]\n",
      "loss: 0.600046  [32064/150000]\n",
      "loss: 0.522602  [38464/150000]\n",
      "loss: 0.585224  [44864/150000]\n",
      "loss: 0.616364  [51264/150000]\n",
      "loss: 0.607755  [57664/150000]\n",
      "loss: 0.672700  [64064/150000]\n",
      "loss: 0.615942  [70464/150000]\n",
      "loss: 0.563604  [76864/150000]\n",
      "loss: 0.516980  [83264/150000]\n",
      "loss: 0.568126  [89664/150000]\n",
      "loss: 0.643335  [96064/150000]\n",
      "loss: 0.598452  [102464/150000]\n",
      "loss: 0.645684  [108864/150000]\n",
      "loss: 0.615591  [115264/150000]\n",
      "loss: 0.621618  [121664/150000]\n",
      "loss: 0.534854  [128064/150000]\n",
      "loss: 0.589854  [134464/150000]\n",
      "loss: 0.635436  [140864/150000]\n",
      "loss: 0.554766  [147264/150000]\n",
      "Epoch 28\n",
      "-------------------------------\n",
      "loss: 0.521061  [   64/150000]\n",
      "loss: 0.636044  [ 6464/150000]\n",
      "loss: 0.637459  [12864/150000]\n",
      "loss: 0.643848  [19264/150000]\n",
      "loss: 0.644321  [25664/150000]\n",
      "loss: 0.495741  [32064/150000]\n",
      "loss: 0.607640  [38464/150000]\n",
      "loss: 0.588970  [44864/150000]\n",
      "loss: 0.559616  [51264/150000]\n",
      "loss: 0.594282  [57664/150000]\n",
      "loss: 0.576745  [64064/150000]\n",
      "loss: 0.550133  [70464/150000]\n",
      "loss: 0.598407  [76864/150000]\n",
      "loss: 0.547898  [83264/150000]\n",
      "loss: 0.620289  [89664/150000]\n",
      "loss: 0.625945  [96064/150000]\n",
      "loss: 0.588795  [102464/150000]\n",
      "loss: 0.564310  [108864/150000]\n",
      "loss: 0.471800  [115264/150000]\n",
      "loss: 0.661256  [121664/150000]\n",
      "loss: 0.512393  [128064/150000]\n",
      "loss: 0.518628  [134464/150000]\n",
      "loss: 0.637960  [140864/150000]\n",
      "loss: 0.634269  [147264/150000]\n",
      "Epoch 29\n",
      "-------------------------------\n",
      "loss: 0.532436  [   64/150000]\n",
      "loss: 0.582553  [ 6464/150000]\n",
      "loss: 0.599198  [12864/150000]\n",
      "loss: 0.599311  [19264/150000]\n",
      "loss: 0.537017  [25664/150000]\n",
      "loss: 0.545210  [32064/150000]\n",
      "loss: 0.594809  [38464/150000]\n",
      "loss: 0.560428  [44864/150000]\n",
      "loss: 0.585562  [51264/150000]\n",
      "loss: 0.615488  [57664/150000]\n",
      "loss: 0.601441  [64064/150000]\n",
      "loss: 0.586801  [70464/150000]\n",
      "loss: 0.612262  [76864/150000]\n",
      "loss: 0.569715  [83264/150000]\n",
      "loss: 0.617154  [89664/150000]\n",
      "loss: 0.615478  [96064/150000]\n",
      "loss: 0.627287  [102464/150000]\n",
      "loss: 0.594675  [108864/150000]\n",
      "loss: 0.527844  [115264/150000]\n",
      "loss: 0.601432  [121664/150000]\n",
      "loss: 0.560899  [128064/150000]\n",
      "loss: 0.590129  [134464/150000]\n",
      "loss: 0.634971  [140864/150000]\n",
      "loss: 0.612034  [147264/150000]\n",
      "Epoch 30\n",
      "-------------------------------\n",
      "loss: 0.663320  [   64/150000]\n",
      "loss: 0.590303  [ 6464/150000]\n",
      "loss: 0.548620  [12864/150000]\n",
      "loss: 0.606458  [19264/150000]\n",
      "loss: 0.551339  [25664/150000]\n",
      "loss: 0.597070  [32064/150000]\n",
      "loss: 0.568118  [38464/150000]\n",
      "loss: 0.656885  [44864/150000]\n",
      "loss: 0.517651  [51264/150000]\n",
      "loss: 0.619143  [57664/150000]\n",
      "loss: 0.576512  [64064/150000]\n",
      "loss: 0.551789  [70464/150000]\n",
      "loss: 0.639050  [76864/150000]\n",
      "loss: 0.568361  [83264/150000]\n",
      "loss: 0.549884  [89664/150000]\n",
      "loss: 0.634776  [96064/150000]\n",
      "loss: 0.555386  [102464/150000]\n",
      "loss: 0.525088  [108864/150000]\n",
      "loss: 0.613388  [115264/150000]\n",
      "loss: 0.529327  [121664/150000]\n",
      "loss: 0.554036  [128064/150000]\n",
      "loss: 0.570808  [134464/150000]\n",
      "loss: 0.573866  [140864/150000]\n",
      "loss: 0.593912  [147264/150000]\n",
      "Done!\n",
      "Train data:\n",
      "AUC value is: 0.6999572362931573\n",
      "Accuracy is: 0.6615666666666666\n",
      "Test data:\n",
      "AUC value is: 0.7051648361054951\n",
      "Accuracy is: 0.66606\n"
     ]
    }
   ],
   "source": [
    "# Fit the membership-inference attack model from the shadow models' outputs\n",
    "# (shadow_attack is star-imported from one of the frame.* modules; presumably frame.ShadowAttack)\n",
    "attack_model = shadow_attack(sha_models=sha_models, tar_model=tar_model, model_num=num_shadowsets, weight_dir=weight_dir, data_name=DATA_NAME, model=MODEL, model_transform=model_transform, \n",
    "                  model_epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=attack_lr, attack_epochs=30, attack_transform=attack_transform, \n",
    "                  device=device, prop_keep=0.5, top_k=None, attack_class=attack_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ab8c2490-c230-4516-a19a-df84a3779011",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Attack_NN(\n",
       "  (linear_relu_stack): Sequential(\n",
       "    (0): Linear(in_features=11, out_features=128, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=128, out_features=64, bias=True)\n",
       "    (3): ReLU()\n",
       "    (4): Linear(in_features=64, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attack_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "d1f87436-bae2-410f-a025-e9757ecd17d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the target model's attack features: confidence vector plus the true label\n",
    "# appended as the last column (matching the 11-feature input of Attack_NN).\n",
    "targetY = train_keep[tar_model]\n",
    "# assumes Y_data holds per-sample class labels aligned with conf_data_all rows -- TODO confirm\n",
    "targetX = np.concatenate((conf_data_all[tar_model], Y_data.reshape(-1, 1)), 1)\n",
    "# Single final cast to float32; the original intermediate astype before the\n",
    "# concatenation was redundant (the result was re-cast to float32 anyway).\n",
    "targetX = targetX.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ebd65e38-9469-4c5f-a5a3-041b675a9d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = None\n",
    "if top_k:\n",
    "    # Use only the top-k values of each probability vector\n",
    "    # (disabled in this run since top_k is None)\n",
    "    targetX, _ = get_top_k_conf(top_k, targetX, targetX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "8ba8d2b7-7d6c-4e17-95bd-f854e8af5af9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC value is: 0.7051648425054993\n",
      "Accuracy is: 0.66606\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.66606"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the trained attack model on the target model's feature set\n",
    "shadow_attack_data = CustomDataset(targetX, targetY, attack_transform)\n",
    "# NOTE(review): lowercase batch_size is used here while the config cell defines BATCH_SIZE -- confirm it is defined\n",
    "shadow_attack_dataloader = DataLoader(shadow_attack_data, batch_size=batch_size, shuffle=True)\n",
    "attack_test_scores, attack_test_mem = get_attack_pred(shadow_attack_dataloader, attack_model, device)\n",
    "# Move predictions off the device before computing sklearn metrics\n",
    "attack_test_scores, attack_test_mem = attack_test_scores.detach().cpu().numpy(), attack_test_mem.detach().cpu().numpy()\n",
    "evaluate_ROC(attack_test_scores, attack_test_mem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c8cb250e-ad72-40b3-8967-46b8538be791",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC value is: 0.6892114804089503\n",
      "Accuracy is: 0.6596\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.6596"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate_ROC(attack_test_scores[risk_rank[:2500]], attack_test_mem[risk_rank[:2500]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebaa471-2271-4310-b936-1030edb6f096",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4bbfa8f-bd1a-46f6-9b79-a7c758d30f27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b077751a-c0ee-4e39-9f70-41b791d348e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06a3e24-5a57-472a-b6f0-78961403f471",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40ceb18e-621c-4c72-8ef7-94498c7b65c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30abf80e-2b00-4b70-ac0c-a7099f987cc4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0ae045-14bd-4159-948c-8881f74cd181",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d64311d6-c09d-47b3-8076-f11eb10134aa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87d308e9-be9d-40fb-be9e-e84efc339052",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0308f234-469f-4670-b4e4-51973cb27ce4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a9b3ad57-7f95-4bc7-8057-05aa0ca4e10a",
   "metadata": {},
   "source": [
    "### 绘制攻击成功率随风险变化曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "03e43bb9-0680-487b-8433-805cb202031f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_all = conf_data_all.argmax(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b7dda5b4-e79d-40b7-beb8-86d21143ac11",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_att = (pred_all == Y_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a79890e7-c8ea-4ddd-9e0a-c51831260ad9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 50000)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_att.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "18020329-0447-496c-bffa-a59f865df12c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 50000)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_keep.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "67ad94bc-8667-4f2b-ab7b-3d32a98e97d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pri_risk_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "973858e9-c4fd-47cd-9941-c5730dfff80a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulators for the risk-vs-accuracy curve. Re-run this cell before\n",
    "# re-running the bucket loop below, since that loop appends to these lists.\n",
    "X_axi = []\n",
    "Y_axi = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "4d44c0fa-c7b9-4f39-a245-0f8e3671141c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pri_risk_rank.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "030de03c-e397-4141-8f84-6cd9910e91a7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Inputs: pri_risk_rank, train_keep, base_att, pri_risk_all\n",
    "# (loss_data_all and conf_data_all were listed in the original note but are not used here)\n",
    "# For each bucket of 500 samples ordered by privacy risk, compute the baseline\n",
    "# attack accuracy pooled over all 100 models and record (mean risk, accuracy).\n",
    "\n",
    "for i in range(100):  # iterate over risk-ordered sample buckets\n",
    "    start = i * 500\n",
    "    end = (i+1) * 500\n",
    "    bucket = pri_risk_rank[start:end]\n",
    "    risk_t = pri_risk_all[bucket]\n",
    "    # Collect per-model slices first and concatenate once at the end --\n",
    "    # avoids the quadratic cost of repeated np.concatenate inside the loop.\n",
    "    pred_parts = []\n",
    "    mem_parts = []\n",
    "    for j in range(100):  # iterate over models\n",
    "        pred_parts.append(base_att[j][bucket])\n",
    "        mem_parts.append(train_keep[j][bucket])\n",
    "    pred_t = np.concatenate(pred_parts, 0)\n",
    "    mem_t = np.concatenate(mem_parts, 0)\n",
    "    acc = metrics.accuracy_score(mem_t, pred_t)\n",
    "    risk = np.mean(risk_t)\n",
    "    X_axi.append(risk)\n",
    "    Y_axi.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4df62ec-4eb6-4c8f-a522-737a9e0d0e06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "238bf52f-5f07-4176-9642-bd34be61a5a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAGwCAYAAABGogSnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIJUlEQVR4nO3dd3xUVf7/8fekF1KAQAIhBELvYJBQBBUjiC6KuIqIgii4q+CicVVYBdTdn6zriqjLgoViF1Fsi18Eg4hAaKFXadITSCCF9Mzc3x+BgZhCBiaZmczr+XjMA3Ln3Dufy81435577rkmwzAMAQAAuAkPRxcAAABQkwg/AADArRB+AACAWyH8AAAAt0L4AQAAboXwAwAA3ArhBwAAuBUvRxdQ0ywWi06cOKGgoCCZTCZHlwMAAKrAMAxlZ2ercePG8vC4ur4btws/J06cUFRUlKPLAAAAV+Do0aNq0qTJVW3D7cJPUFCQpJJ/vODgYAdXAwAAqiIrK0tRUVHW8/jVcLvwc+FSV3BwMOEHAAAXY48hKwx4BgAAboXwAwAA3ArhBwAAuBXCDwAAcCuEHwAA4FYIPwAAwK0QfgAAgFsh/AAAALdC+AEAAG6F8AMAANwK4QcAALgVwg8AAHArbvdgUwAAYF+GYaig2KK8QrPyiswymaRGIf6OLqtChB8AAGq5wmKL8orMyi8yK7fQbA0p1j+LzMorLD7/s6Xk7xfa/m69/PKWF5llGBc/r0ezevr8z70ct8OXQfgBAMCBzBbjfNAoVn6hRblFxWXDSSV/5haZlX/+59zz4cT69/PLiy3G5QuxEx9PD5lMNfZxV4TwAwBABSwWQ/nF50PG74JF3vnQUV4vSF7h73tVKv6z0Gypsf3x9DApwNtTfj6e8vf2VICPp/y8L/n7Jcv9vUveC/DxlL/PJX8/397//HLr388v9/J0/uHEhB8AgEv6/TiTMj0iv+8FuWS59TJPUenLPL8PLflFNRdMTCaVDhaV/GkNJz6eCvC+NJx4nW/nIX9vrzLreXuaZHL2bpkaQPgBAFSL8seZFCuv0HLxMk8Fl26uZJxJdfPz9qig18NL/hfeuyR8BPh4le5VOb9emV6V88t9vTwIJjWE8AMAuGIZuYU6cPqcDpzKKfnz9DkdOJ2jY2dzVWSu2XEmft4lgaPCXpPfhRabLvl4ecrDg2BSWxB+AACVMlsMncjI0/7T53Tg1LlSYSc9p/Cy6186zqS8Xo9LL91U5ZKPv4+nAry95OdzsSfGFcaZwHkQfgAAkqTcwmIdPJ1j7b05cD7sHErLUUFxxWNfGof4qUXDOmrRoI5aNAhUiwZ1FB0WqDq+JZeAfLwIJnAuhB8AcCOGYej0uYIyl6kOnDqn4xl5Fa7n4+WhmLDAiwHnfNhpHhaoQF9OJXAt/MYCQC1UZLboyJlcHTh17vzlqothJzu/uML16gX6WHtvWjSooxYNS/7epG6APBnzglqC8AMALiwrv+j8OJyLl6kOnD6nw+m5FU5s52GSouoFlLpM1bJhHcU0qKN6gT41vAdAzSP8AICTs1gMnczKvzjY+HxPzv7T53Q6u6DC9QJ8PEsFnAuXqqLrB8jP27MG9wBwLoQfAHASFouhg2k52pOSVeoy1cHTOcorMle4Xniw78XLVJeMx4kI9uP2bKAchB8AcADDMHQ4PVfbjmdq+7EMbTuWqR3HM5VTWH7I8fIwqVlYYKnLVC0a1FFMg0AF+XnXcPWAayP8AEA1MwxDxzPytP1YprYdz9S2YxnafixTWeUMPPbz9lDbiGC1alin1O3jUfUC5M1cNoBdEH4AwM5Ss/K19WiGth/P1LZjmdp+PFNnypkM0MfLQ+0aBatLkxB1igxR5yahatEgkAn7gGpG+AGAq5B2rqCkR+dYprYfL7l8daqcQcheHia1bRSkTpGh6nw+7LQOD2ICQMABCD8AUEUZuYUXe3OOlVy+OpGZX6adh0lq
HR5U0psTFarOkSFqExHEHVaAkyD8AEA5svOLtON4lrYfz9DW82HnyJncMu1MJqlFgzrqHBmiTk1C1LlJiNo3CpG/D0EHcFaEHwBuL7ewWDtPZJ3v0cnQtuOZOng6p9y2zeoHqFOTUGvY6dA4mLutABdD+AHgdorMFm09mqFV+9O0Zn+6Nh05W+5syJGh/urcpGQgcucmIerYOEQhAQQdwNURfgDUeoZhaG9qtlbvT9fq/WladzC9zHw64cG+JSHnfI9Op8gQ1a/j66CKAVQnwg+AWul4Rp5W70sr6d05kK60c6XvwKob4K3eLcN0Xcsw9WkRpqb1AxxUKYCaRvgBUCuczSlU0sGSnp3V+9P0W3rpwcn+3p7q0bye+rSsrz4tw9QuIphHPwBuivADwCXlF5m14bcz1nE7O05kyrhk2I6nh0ldmoToupZh6t0yTN2ahsrXizuwABB+ALiIYrNF249nas2BdK3al6bkw2dVaLaUatM6vI56tyi5lBUXU4+7sACUi/ADwCkZhqEDp3O0en/JuJ21B9OV/btnYTUK8VOflmEll7JahKlhsJ+DqgXgSgg/AJzGycw8rd6frjXnBymnZJWePTnYz0u9WtS3XsqKCQuUycS4HQC2IfwAcJhLByknHUjXwbTSEwv6eHno2mZ1rZeyOkaGyJNBygCuEuEHQI3JLSzW+kNntOZASeDZdTKr1CBlD5PUqUmo+rQouSMrNrouz8MCYHeEHwDVprDYoi1HM7R6f5rWHEjTlqMZKjKXnkn5wiDl3i3qKy6mvkL8GaQMoHoRfgDYjcViaNfJLK05kKbV+9O14bczyv3dTMqRof7WuXZ6taivhkEMUgZQswg/AK5KZm6Rvtt2omTczsF0ZeQWlXq/fqCPerWor94tSu7KalovgEHKAByK8APgiuQVmjV/zW+atWK/si65BT3Qx1NxMfXV+/y4nTbhQcykDMCpEH4A2KTYbNHC5GOa8eOvSs0qeV5W6/A6Gty5sXq3DFPnJiHy9vRwcJUAUDHCD4AqMQxD/7cjRf/+Ya/1lvTIUH89NaC17ugayS3oAFwG4QfAZa3Zn6ZXluzR1mOZkqR6gT4af2NLjejZlOdlAXA5hB8AFdpxPFOvLNmjX/alSSoZzzOmb4zG9otRHV/+8wHANfFfLwBl/JaWo38v3av/bTspSfL2NGlEXLTG92+psDq+Dq4OAK4O4QeA1ansfL2ZuE+frT+qYoshk0ka0jVST8a3VtP6AY4uDwDswuG3ZMycOVPNmjWTn5+f4uLitH79+grbFhUV6aWXXlKLFi3k5+enLl26aMmSJTVYLVA7ZeUX6d8/7NX1/1qhj9YeUbHF0A1tGmjx4331+rCuBB8AtYpDe34WLFighIQEzZ49W3FxcZoxY4YGDhyovXv3qmHDhmXaP//88/roo4/07rvvqm3btvrhhx905513as2aNerWrZsD9gBwbbmFxfp47RHNXLHfOjlht6ahevaWtuoZU9/B1QFA9TAZhmFcvln1iIuL07XXXqv//Oc/kiSLxaKoqCg9/vjjmjhxYpn2jRs31nPPPadx48ZZl911113y9/fXRx99VKXPzMrKUkhIiDIzMxUcHGyfHQFczOnsAn2Q9Js+XHvYGnpaNqyjpwe20YD24czADMDp2PP87bCen8LCQiUnJ2vSpEnWZR4eHoqPj1dSUlK56xQUFMjPr/RzgPz9/bVq1aoKP6egoEAFBQXWn7Oysq6ycsB17T+Vrfd+OaRFm4+rsNgiSYquH6BxN7TU0Gsi5cXkhADcgMPCT1pamsxms8LDw0stDw8P1549e8pdZ+DAgZo+fbr69eunFi1aKDExUYsWLZLZbC63vSRNmzZNL774ol1rB1yJYRhaf+iM3ll5UIl7TlmXd2saqj/1i9HN7SOYoBCAW3Gpu73eeOMNjR07Vm3btpXJZFKLFi00evRozZ07t8J1Jk2apISEBOvPWVlZioqKqolyAYcqNlu0ZGeK
3l150Do5ockkDWgfrkf6xSg2up6DKwQAx3BY+AkLC5Onp6dSU1NLLU9NTVVERES56zRo0EBff/218vPzlZ6ersaNG2vixImKiYmp8HN8fX3l68u8JHAfOQXF+nzjUc1ZdUjHzuZJkny9PPTH2CZ6+LrmimlQx8EVAoBjOSz8+Pj4KDY2VomJiRoyZIikkgHPiYmJGj9+fKXr+vn5KTIyUkVFRfryyy91zz331EDFgHM7lZWv+Wt+00drD1ufsl4v0Ecje0XrgZ7Rqs/khAAgycGXvRISEjRq1Ch1795dPXr00IwZM5STk6PRo0dLkkaOHKnIyEhNmzZNkrRu3TodP35cXbt21fHjx/XCCy/IYrHomWeeceRuAA71a2q23vvloL7efEKF5pJBzM3DAjWmb3PddU0T+Xnz7C0AuJRDw8+wYcN0+vRpTZkyRSkpKeratauWLFliHQR95MgReXhcvPskPz9fzz//vA4ePKg6dero1ltv1YcffqjQ0FAH7QHgGIZh6Jd9aXpv1SGt/PW0dXn36Lp6pF+M4tuFy4NBzABQLofO8+MIzPMDV5ZfZNa3W07ovVUH9WvqOUmSh0ka2CFCY/rGKDa6roMrBIDqUSvm+QFQdennCvTR2iP6cO1vSjtXKKnkCevDrm2q0X2aKaoej58AgKoi/ABObP+pbM1ZdUhfbro4KWHjED+N7tNcw3pEKdjP28EVAoDrIfwATsYwDK3en673Vh3Uir0Xx/N0aRKiMX1jdEvHCHkzEzMAXDHCD+Akis0WfbX5uOasOqQ9KdmSLk5KOPb8eB6euQUAV4/wAziBjb+d0fNf77CGngAfT93TPUqj+zRTdP1AB1cHALUL4QdwoPRzBfrn/+3RwuRjkqTQAG/9qV8L3dejqUICGM8DANWB8AM4gMVi6LMNR/XKkj3KzCuSJN17bZSeuaWt6gX6OLg6AKjdCD9ADdtxPFPPfb1DW49mSJLaNwrW34d0ZI4eAKghhB+ghmTmFWn60r36cO1hWQwpyNdLCQNa64Ge0fLi7i0AqDGEH6CaGYahb7ac0D8W71bauQJJ0h1dG+u5W9upYbCfg6sDAPdD+AGq0b7UbE3+ZofWHjwjSYppEKh/3NFRvVuGObgyAHBfhB+gGuxNydYHSb9pwYajKrYY8vP20OP9W2ls3xj5eHGJCwAcifAD2ElhsUVLd6Xog6TDWn/ojHX5ze3DNeUP7Xn+FgA4CcIPcJVOZubp03VH9OmGozqdXTKmx9PDpJvbhWtk72j1bsElLgBwJoQf4AoYhqE1B9L1YdJhLdudKrPFkCQ1CPLV8B5NNbxHlBqF+Du4SgBAeQg/gA2y84v0RfIxfbj2sA6ezrEuj2teTw/0itaA9hGM6QEAJ0f4Aapo/aEzeuKzzTqRmS9JCvTx1NBrmuiBXtFqHR7k4OoAAFVF+AEuo9hs0ZvL9+s/y/fJYkhN6wVobL8Y3dktUnV8+QoBgKvhv9xAJY6dzdUTn23RxsNnJUl/jG2iF2/voEBCDwC4LP4LDlRg8baTmrhom7LzixXk66V/3NlRd3SNdHRZAICrRPgBfie3sFgvfrtLCzYelSR1axqqN+/txjw9AFBLEH6AS+w4nqm/fLZZB0/nyGSSxt3QUhPiW8mbB48CQK1B+AFUMm/P3NW/6ZX/26NCs0Xhwb56fVhXJigEgFqI8AO3dyanUE99vkU/7T0tqeRxFP+6q7PqBvo4uDIAQHUg/MCtbT2aocc+3qTjGXny8fLQ5Nva6f6e0TKZTI4uDQBQTQg/cEuGYeiT9Uf04re7VGi2qFn9AM26P1btGgU7ujQAQDUj/MDt5BWa9dzX27Vo03FJ0oD24fr3PV0U7Oft4MoAADWB8AO3cigtR49+lKw9KdnyMEnP3tJWj/SL4TIXALgRwg/cxg87U/TXz7cqu6BYYXV89Nbwa9SrRX1HlwUAqGGEH9R6xWaLXl26V2//fFCS1D26rmaOuEbhwX4OrgwA4AiE
H9RqaecKNP6TTVp78Iwk6aE+zTXp1rZMWggAbozwg1rr4OlzGjVvvY6eyVOgj6de+WNn/aFzY0eXBQBwMMIPaqXkw2c15v0NOptbpKb1AjT3we5q2TDI0WUBAJwA4Qe1ztKdKXr8080qKLaoc5MQzRl1rRoE+Tq6LACAkyD8oFb5cO1hTf1mhyyG1L9tQ/3nvm4K8OHXHABwEWcF1AqGYehfP+zVrBUHJEnDe0Tp73d0lBcDmwEAv0P4gcsrLLbo2S+36avNJTM2J9zcWo/3b8nEhQCAchF+4NKy84v06EebtGp/mjw9TJo2tJPu6R7l6LIAAE6M8AOXlZqVr1Fz12tPSrYCfTz13/tjdX3rBo4uCwDg5Ag/cDmZeUVauPGo3ll5UKeyC9QgyFfzHrxWHSNDHF0aAMAFEH7gMvalZuv9pN/0ZfJx5RWZJUktGgRq/ugeiqoX4ODqAACugvADp2a2GPppzynNX/ObVu1Psy5vGxGkUb2baUjXSPn7eDqwQgCAqyH8wClduLT1QdJhHTmTK0nyMEk3tw/Xg72bq2dMPe7mAgBcEcIPnM7320/q6YVblVNYcmkrxN9b914bpft7RnN5CwBw1Qg/cCpzVx3S3xfvkmFIrcPraHSf5lzaAgDYFeEHTsFiMTTt/3br3V8OSZJG9orW1MEd5OnBpS0AgH0RfuBwBcVm/XXhNn239YQk6dlb2urP18cwpgcAUC0IP3CozLwi/enDjVp78Iy8PEx69e7OurNbE0eXBQCoxQg/cJiTmXl6cO4G7U3NVh1fL826/xr1bcUMzQCA6kX4gUP8mpqtUXPX62RmvhoG+Wre6GvVoTEzNAMAqh/hBzVu7cF0PfLBRmXlF6tFg0C9/1APNanLLewAgJpB+EGNWrjxqP721XYVmQ11j66r90Z1V2iAj6PLAgC4EcIPaoTFYujVpXs1a8UBSdJtnRrptXu6yM+b+XsAADWL8INql1tYrIQFW7VkZ4ok6fH+LfVkfGt5MIcPAMABCD+oVqlZ+Rrz/kZtP54pH08PvfLHTtzKDgBwKMIPqs2O45ka8/5GpWTlq16gj955IFbdm9VzdFkAADdH+EG1+GFnip74bIvyisxq1bCO5oy6Vk3rc0cXAMDxCD+wq6z8Ik1f+qveT/pNhiH1bRWmmSOuUbCft6NLAwBAEuEHdmIYhr7dekJ//99upZ0rkFTycNIpf2gvL08PB1cHAMBFhB9ctf2nsjX5651KOpguSYoJC9RLd3TUda3CHFwZAABlEX5wxXILi/XW8v1675eDKjIb8vXy0OP9W2psvxj5ejF/DwDAORF+cEV+2XdaE7/cruMZeZKkm9o21Au3d1BUPQY1AwCcG+EHNks+fFYPz9+oQrNFkaH+euH2Drq5fbijywIAoEocPhJ15syZatasmfz8/BQXF6f169dX2n7GjBlq06aN/P39FRUVpSeffFL5+fk1VC2OZ+TpTx+WBJ/4dg21LKEfwQcA4FJsDj+jRo3SypUr7fLhCxYsUEJCgqZOnapNmzapS5cuGjhwoE6dOlVu+08++UQTJ07U1KlTtXv3bs2ZM0cLFizQ3/72N7vUg8rlFhZr7PsblXauUG0jgvTGvd0U4EPnIQDAtdgcfjIzMxUfH69WrVrp5Zdf1vHjx6/4w6dPn66xY8dq9OjRat++vWbPnq2AgADNnTu33PZr1qxRnz59dN9996lZs2YaMGCAhg8fftneIlw9i8XQU59v1a6TWaof6KP3RnVXoC/BBwDgemwOP19//bWOHz+uRx99VAsWLFCzZs00aNAgffHFFyoqKqrydgoLC5WcnKz4+PiLxXh4KD4+XklJSeWu07t3byUnJ1vDzsGDB/X999/r1ltvrfBzCgoKlJWVVeoF272RuE//tyNF3p4mzX4gVk3qMrAZAOCarmjMT4MGDZSQkKCtW7dq3bp1atmypR544AE1btxYTz75pPbt23fZbaSlpclsNis8vPR4kfDwcKWk
pJS7zn333aeXXnpJ1113nby9vdWiRQvdcMMNlV72mjZtmkJCQqyvqKgo23YW+t+2E3ojseSY/r8hnXQtz+cCALiwqxrwfPLkSS1btkzLli2Tp6enbr31Vm3fvl3t27fX66+/bq8arVasWKGXX35Z//3vf7Vp0yYtWrRIixcv1t///vcK15k0aZIyMzOtr6NHj9q9rtpsx/FM/XXhVknSmOua655rCY8AANdm86CNoqIiffvtt5o3b56WLl2qzp0764knntB9992n4OBgSdJXX32lhx56SE8++WSF2wkLC5Onp6dSU1NLLU9NTVVERES560yePFkPPPCAxowZI0nq1KmTcnJy9Mgjj+i5556Th0fZLOfr6ytfX19bdxOSTmXla+wHG5VfZNH1rRto0q3tHF0SAABXzebw06hRI1ksFutA465du5Zpc+ONNyo0NLTS7fj4+Cg2NlaJiYkaMmSIJMlisSgxMVHjx48vd53c3NwyAcfTs2QmYcMwbN0VVGLXiSz95bPNOpmZrxYNAvXWfd3k6WFydFkAAFw1m8PP66+/rrvvvlt+fn4VtgkNDdWhQ4cuu62EhASNGjVK3bt3V48ePTRjxgzl5ORo9OjRkqSRI0cqMjJS06ZNkyQNHjxY06dPV7du3RQXF6f9+/dr8uTJGjx4sDUE4epYLIbmrj6kfy3Zq0KzRQ2CfPXeqGt5KjsAoNawOfzcfvvtys3NLRN+zpw5Iy8vL+ulr6oYNmyYTp8+rSlTpiglJUVdu3bVkiVLrIOgjxw5Uqqn5/nnn5fJZNLzzz+v48ePq0GDBho8eLD+3//7f7buBsqRmpWvvy7cql/2pUmS4ts11Ct3dVb9Olw2BADUHibDxutFgwYN0uDBg/XYY4+VWj579mx9++23+v777+1aoL1lZWUpJCREmZmZNgW12m7JjhRNWrRNZ3OL5Oftocl/aK/7ejSVycSlLgCA49nz/G3z3V7r1q3TjTfeWGb5DTfcoHXr1l1VMah5uYXFmrRom/78UbLO5hapQ+Ng/e/xvhoRF03wAQDUSjZf9iooKFBxcXGZ5UVFRcrLy7NLUagZuYXFuvedtdp2LFMmk/RIvxg9dXMb+Xg5/JFvAABUG5vPcj169NA777xTZvns2bMVGxtrl6JQ/cwWQxM+26JtxzJVN8BbHz8cp0mD2hF8AAC1ns09P//4xz8UHx+vrVu36qabbpIkJSYmasOGDVq6dKndC0T1mPb9bi3blSofTw+9O7K7ujNrMwDATdj8v/l9+vRRUlKSoqKi9Pnnn+u7775Ty5YttW3bNvXt27c6aoSdfbT2sN5bVTIVwat3dyb4AADcis13e7k6d7/ba8XeU3r4/Y0yWww9dXNrPX5TK0eXBADAZdnz/G3zZa9L5efnq7CwsNQydwwUrmJPSpbGf7JZZouhu65povH9Wzq6JAAAapzNl71yc3M1fvx4NWzYUIGBgapbt26pF5zTqax8PTRvg84VFCuueT1NG9qJW9kBAG7J5vDz9NNPa/ny5Zo1a5Z8fX313nvv6cUXX1Tjxo31wQcfVEeNuEoZuYV6+P2NOpGZr5iwQL39QCx3dQEA3JbNl72+++47ffDBB7rhhhs0evRo9e3bVy1btlR0dLQ+/vhjjRgxojrqxBU6djZXo+au14HTOaob4K25D16r0AAfR5cFAIDD2Py//2fOnFFMTIykkvE9Z86ckSRdd911WrlypX2rw1XZdSJLQ/+7RgdO56hRiJ8W/KmXmoUFOrosAAAcyubwExMTY31ie9u2bfX5559LKukRCg0NtWtxuHJr9qdp2NtJOpVdoNbhdfTlo73VOjzI0WUBAOBwNoef0aNHa+vWrZKkiRMnaubMmfLz89OTTz6pp59+2u4Fwnbfbj2hUfPWK7ugWD2a19PCP/dW41B/R5cFAIBTuOp5fg4fPqzk5GS1bNlSnTt3tldd1aa2z/Pz2fojmrhouyTp1k4Rmn5PV/l5ezq4KgAAro7D
nupeVFSkm266Sfv27bMui46O1tChQ10i+NR2m4+c1eRvdkiSRvWK1lvDryH4AADwOzbd7eXt7a1t27ZVVy24CmdzCjX+k80qMhsa1DFCL9zegXl8AAAoh81jfu6//37NmTOnOmrBFbJYDCV8vkXHM/LUrH6AXvljZ4IPAAAVsHmen+LiYs2dO1c//vijYmNjFRhY+tbp6dOn2604VM2snw/op72n5ePloZkjrlGwn7ejSwIAwGnZHH527Niha665RpL066+/lnqP3oaal3QgXa8t3StJeun2DurQOMTBFQEA4NxsDj8//fRTddSBK3AqO1+Pf7pZFkMaek2khl0b5eiSAABwejzgyUWZLYYmfLpFaedKJjH8x5CO9LwBAFAFNvf83HjjjZWeZJcvX35VBaFq/rN8v5IOpivAx1P/HRGrAB+bDyUAAG7J5jNm165dS/1cVFSkLVu2aMeOHRo1apS96kIlkg6k643EkvFW/xjSUS0b1nFwRQAAuA6bw8/rr79e7vIXXnhB586du+qCULn0cwWa8FnJOJ8/xjbR0GuaOLokAABcit3G/Nx///2aO3euvTaHclgshp5auFWnsgvUokGgXrqjg6NLAgDA5dgt/CQlJcnPz89em0M53v3loFbsPS3f8/P5MM4HAADb2Xz2HDp0aKmfDcPQyZMntXHjRk2ePNluhaG0Hccz9eoPJfP5vHB7B7WNqH0PZQUAoCbYHH5CQkpPoufh4aE2bdropZde0oABA+xWGC4yDEMvfLtTxZaS53bdy3w+AABcMZvDz7x586qjDlTi260ntPHwWfl7e2rK4PbM5wMAwFWweczPhg0btG7dujLL161bp40bN9qlKFyUU1Csad/vkSSNu7GFGoX4O7giAABcm83hZ9y4cTp69GiZ5cePH9e4cePsUhQu+u+K/UrJyldUPX+N6Rvj6HIAAHB5NoefXbt2WR9seqlu3bpp165ddikKJQ6n5+jdlYckSc/f1l5+3p4OrggAANdnc/jx9fVVampqmeUnT56Ulxe3XtvTPxbvVqHZor6twjSgfbijywEAoFawOfwMGDBAkyZNUmZmpnVZRkaG/va3v+nmm2+2a3HubONvZ7RsV6q8PEyayiBnAADsxuaumn//+9/q16+foqOj1a1bN0nSli1bFB4erg8//NDuBbqrt1celFTyCIuWDYMcXA0AALWHzeEnMjJS27Zt08cff6ytW7fK399fo0eP1vDhw+Xt7V0dNbqdA6fP6cfdJZcWGeQMAIB9XdEgncDAQD3yyCP2rgXnzVl1SIYhxbdryBPbAQCwM5vH/EybNq3cB5jOnTtXr7zyil2Kcmdp5wr0ZfIxSdIj/Vo4uBoAAGofm8PP22+/rbZt25ZZ3qFDB82ePdsuRbmzD5IOq6DYoi5Robq2WV1HlwMAQK1jc/hJSUlRo0aNyixv0KCBTp48aZei3FVeoVkfJv0mSXqkbwx3eAEAUA1sDj9RUVFavXp1meWrV69W48aN7VKUu/py0zGdzS1SVD1/DezAvD4AAFQHmwc8jx07Vk888YSKiorUv39/SVJiYqKeeeYZPfXUU3Yv0J18u+WEJGlkz2by8rQ5lwIAgCqwOfw8/fTTSk9P12OPPabCwkJJkp+fn5599llNmjTJ7gW6i7RzBdp4+IwkaVCnCAdXAwBA7WVz+DGZTHrllVc0efJk7d69W/7+/mrVqpV8fX2roz63sXz3KVkMqWNksJrUDXB0OQAA1FpX/DCuOnXq6Nprr7VnLW7th50pkqQB7en1AQCgOl1R+Nm4caM+//xzHTlyxHrp64JFixbZpTB3klNQrF/2p0mSBjDQGQCAamXzqNrPPvtMvXv31u7du/XVV1+pqKhIO3fu1PLlyxUSElIdNdZ6K389rcJii6LrB6hNOM/xAgCgOtkcfl5++WW9/vrr+u677+Tj46M33nhDe/bs0T333KOmTZtWR4213tJdJc/xGtA+nLl9AACoZjaHnwMH
Dui2226TJPn4+CgnJ0cmk0lPPvmk3nnnHbsXWNsVmS1KPP8Q0wEdGO8DAEB1szn81K1bV9nZ2ZJKnvC+Y8cOSVJGRoZyc3PtW50bWHfwjLLyi1U/0EfXNOVxFgAAVDebBzz369dPy5YtU6dOnXT33XdrwoQJWr58uZYtW6abbrqpOmqs1S7c5XVz+3B5enDJCwCA6mZz+PnPf/6j/Px8SdJzzz0nb29vrVmzRnfddZeef/55uxdYm2XnF+nrzcclSbd05JIXAAA1webwU69ePevfPTw8NHHiRLsW5E4+XX9E2QXFatmwjvq1auDocgAAcAs8QMpBCostmrPqkCTpkX4x8uCSFwAANYLw4yDfbDmu1KwChQf76o6ujR1dDgAAboPw4wAWi6F3Vh6UJD3Up7l8vTwdXBEAAO6D8OMAP/96WvtOnVMdXy8Nj2NiSAAAapLN4eenn36q8L2ZM2deVTHuYtH5O7zu7t5EwX7eDq4GAAD3YnP4GTp0qJKTk8ssf+ONNzRp0iS7FFWb5RWarTM6396FsT4AANQ0m8PPq6++qkGDBmnPnj3WZa+99pqmTJmixYsX27W42mj5nlPKLTSrSV1/dY0KdXQ5AAC4HZvn+RkzZozOnDmj+Ph4rVq1SgsWLNDLL7+s77//Xn369KmOGmuV/207IUm6rXMjHmIKAIAD2Bx+JOmZZ55Renq6unfvLrPZrB9++EE9e/a0d221Tk5BsZbvOSVJGtyZS14AADhClcLPm2++WWZZZGSkAgIC1K9fP61fv17r16+XJP3lL3+xb4W1yI+7U1VQbFF0/QB1aBzs6HIAAHBLJsMwjMs1at68edU2ZjLp4MGDNhcxc+ZMvfrqq0pJSVGXLl301ltvqUePHuW2veGGG/Tzzz+XWX7rrbdWacxRVlaWQkJClJmZqeDgmg0gI95bq9X70zX+xpb668A2NfrZAAC4Mnuev6vU83Po0KGr+pDKLFiwQAkJCZo9e7bi4uI0Y8YMDRw4UHv37lXDhg3LtF+0aJEKCwutP6enp6tLly66++67q61Ge/g1NVur96fLwyTd2yPK0eUAAOC2HD7J4fTp0zV27FiNHj1a7du31+zZsxUQEKC5c+eW275evXqKiIiwvpYtW6aAgIAKw09BQYGysrJKvRxh/prfJEkD2keoSd0Ah9QAAACuIPzcddddeuWVV8os/9e//mVz70thYaGSk5MVHx9/sSAPD8XHxyspKalK25gzZ47uvfdeBQYGlvv+tGnTFBISYn1FRdV8r0tmbpEWbTomSXqwT7Ma/3wAAHCRzeFn5cqVuvXWW8ssHzRokFauXGnTttLS0mQ2mxUeHl5qeXh4uFJSUi67/vr167Vjxw6NGTOmwjaTJk1SZmam9XX06FGbarSH/9txUvlFFrWNCFJc83o1/vkAAOAim291P3funHx8fMos9/b2rvFLSnPmzFGnTp0qHBwtSb6+vvL19a3Bqspaf+iMJCm+XThz+wAA4GA29/x06tRJCxYsKLP8s88+U/v27W3aVlhYmDw9PZWamlpqeWpqqiIiIipdNycnR5999pkefvhhmz7TEdb/VhJ+etDrAwCAw9nc8zN58mQNHTpUBw4cUP/+/SVJiYmJ+vTTT7Vw4UKbtuXj46PY2FglJiZqyJAhkiSLxaLExESNHz++0nUXLlyogoIC3X///bbuQo06kZGnY2fz5GGSromu6+hyAABwezaHn8GDB+vrr7/Wyy+/rC+++EL+/v7q3LmzfvzxR11//fU2F5CQkKBRo0ape/fu6tGjh2bMmKGcnByNHj1akjRy5EhFRkZq2rRppdabM2eOhgwZovr169v8mTVpw/len46RIarje0UTagMAADu6orPxbbfdpttuu80uBQwbNkynT5/WlClTlJKSoq5du2rJkiXWQdBHjhyRh0fpq3N79+7VqlWrtHTpUrvUUJ3WnR/v06MZl7wAAHAGVZrhuTap6RmeB7z+s35NPae3H4jV
wA6Vj2MCAADlq/EZni9lNpv1+uuv6/PPP9eRI0dKzbYsSWfOnLmqgmqT7Pwi7Tt1TpJ0TVPG+wAA4AxsvtvrxRdf1PTp0zVs2DBlZmYqISFBQ4cOlYeHh1544YVqKNF1bT+WKcOQIkP91SDIsbfbAwCAEjaHn48//ljvvvuunnrqKXl5eWn48OF67733NGXKFK1du7Y6anRZm49mSJK6Ng11aB0AAOAim8NPSkqKOnXqJEmqU6eOMjMzJUl/+MMfqvRUdXey5Xz46RYV6tA6AADARTaHnyZNmujkyZOSpBYtWljvuNqwYYPDZ1J2Nlsv9PwQfgAAcBo2h58777xTiYmJkqTHH39ckydPVqtWrTRy5Eg99NBDdi/QVZ3OLtCp7AKZTFKHxiGOLgcAAJxn891e//znP61/HzZsmKKjo7VmzRq1atVKgwcPtmtxrmxvSrYkqXn9QPn7eDq4GgAAcIHN4WflypXq3bu3vLxKVu3Zs6d69uyp4uJirVy5Uv369bN7ka5oT0rJQ17bRAQ5uBIAAHApmy973XjjjeXO5ZOZmakbb7zRLkXVBnvO9/y0jaj+iRQBAEDV2Rx+DMOQyWQqszw9PV2BgYF2Kao2oOcHAADnVOXLXkOHDpUkmUwmPfjgg6Xu7DKbzdq2bZt69+5t/wpdkNliaF9qyczOhB8AAJxLlcNPSEjJHUuGYSgoKEj+/v7W93x8fNSzZ0+NHTvW/hW6oCNnclVQbJGft4ea1gtwdDkAAOASVQ4/8+bNkyQ1a9ZMTz/9tAICOKlX5MKdXq3Dg+TpUfYSIQAAcBybx/z8/PPPZR5mKpU8bbV///52KcrV/ZpaEn5aNeSSFwAAzsZu4Sc/P1+//PKLXYpydb+l50iSYhowABwAAGdT5cte27Ztk1Qy5mfXrl1KSUmxvmc2m7VkyRJFRkbav0IXdOxMniSpSV3/y7QEAAA1rcrhp2vXrjKZTDKZTOVe3vL399dbb71l1+Jc1bGzuZKkKAY7AwDgdKocfg4dOiTDMBQTE6P169erQYMG1vd8fHzUsGFDeXryGIfCYotOZuVLkqLqEn4AAHA2VQ4/0dHRkiSLxVJtxdQGqVn5MgzJx9NDYXV8HF0OAAD4HZuf7XXBrl27dOTIkTKDn2+//farLsqVZeYVSZLqBnqXOxM2AABwLJvDz8GDB3XnnXdq+/btMplMMgxDkqwnerPZbN8KXUxGbkn4CfWn1wcAAGdk863uEyZMUPPmzXXq1CkFBARo586dWrlypbp3764VK1ZUQ4mu5WxuSU9YSIC3gysBAADlsbnnJykpScuXL1dYWJg8PDzk4eGh6667TtOmTdNf/vIXbd68uTrqdBkZFy57EX4AAHBKNvf8mM1mBQWVzFwcFhamEydOSCoZEL137177VueCMs/3/HDZCwAA52Rzz0/Hjh21detWNW/eXHFxcfrXv/4lHx8fvfPOO4qJiamOGl3KmZzzY34C6fkBAMAZ2Rx+nn/+eeXklDy+4aWXXtIf/vAH9e3bV/Xr19eCBQvsXqCrOX2uQJLUoI6vgysBAADlsTn8DBw40Pr3li1bas+ePTpz5ozq1q3Lrd2S0rJLwk8Y4QcAAKd0xfP8XKpevXr22EytkHaO8AMAgDOzecAzKncmp2TAc31mdwYAwCkRfuzIMAxl5ZcMeA7xZ8AzAADOiPBjRwXFFhWZS2a8Dib8AADglAg/dpR1foJDD5MU6MMT7gEAcEaEHzu6cMkr2J+HmgIA4KwIP3aUlV8sSarja5eb6AAAQDUg/NhR9vnwE+THeB8AAJwV4ceOss9f9gryo+cHAABnRfixo3MXen647AUAgNMi/NjRuYKS8BNI+AEAwGkRfuyooNgiSfLz5p8VAABnxVnajvKLzJIkP2/m+AEAwFkRfuzoYs8P4QcAAGdF+LEja8+PF/+sAAA4K87SdlRQVNLz40P4AQDAaXGWtqMiM+EHAABnx1najgouhB9P/lkBAHBWnKXt
qOj8gGdven4AAHBanKXt6MJlL296fgAAcFqcpe2okMteAAA4Pc7SdnQ2p+TBpiEBPNUdAABnRfixo7zz8/zwYFMAAJwX4ceOii0ll708PUwOrgQAAFSE8GNHZrMhSfLy4J8VAABnxVnajootJeGHnh8AAJwX4ceOzOfDj5cn4QcAAGdF+LGjCz0/HibCDwAAzorwY0eF52d49mWGZwAAnBZnaTtihmcAAJwfZ2k7KmbMDwAATo/wYyeGYVj/zpgfAACcF+HHTiwXs4+40x0AAOdF+LET8yXpx0TPDwAATovwYyeWUpe9HFgIAAColMPDz8yZM9WsWTP5+fkpLi5O69evr7R9RkaGxo0bp0aNGsnX11etW7fW999/X0PVVswoddmL9AMAgLNy6OPHFyxYoISEBM2ePVtxcXGaMWOGBg4cqL1796phw4Zl2hcWFurmm29Ww4YN9cUXXygyMlKHDx9WaGhozRf/OxceairxeAsAAJyZQ8PP9OnTNXbsWI0ePVqSNHv2bC1evFhz587VxIkTy7SfO3euzpw5ozVr1sjb21uS1KxZs5osuULF5otdP8zzAwCA83LYWbqwsFDJycmKj4+/WIyHh+Lj45WUlFTuOt9++6169eqlcePGKTw8XB07dtTLL78ss9lc4ecUFBQoKyur1Ks6FF8y4JmeHwAAnJfDwk9aWprMZrPCw8NLLQ8PD1dKSkq56xw8eFBffPGFzGazvv/+e02ePFmvvfaa/vGPf1T4OdOmTVNISIj1FRUVZdf9uODCPD/kHgAAnJtLXZ+xWCxq2LCh3nnnHcXGxmrYsGF67rnnNHv27ArXmTRpkjIzM62vo0ePVk9t5zt+GOwMAIBzc9iYn7CwMHl6eio1NbXU8tTUVEVERJS7TqNGjeTt7S1PT0/rsnbt2iklJUWFhYXy8fEps46vr698fX3tW3w5LtzqTvYBAMC5Oaznx8fHR7GxsUpMTLQus1gsSkxMVK9evcpdp0+fPtq/f78sl9xZ9euvv6pRo0blBp+adGGSQ3p+AABwbg697JWQkKB3331X77//vnbv3q1HH31UOTk51ru/Ro4cqUmTJlnbP/roozpz5owmTJigX3/9VYsXL9bLL7+scePGOWoXrC70/DDYGQAA5+bQW92HDRum06dPa8qUKUpJSVHXrl21ZMkS6yDoI0eOyMPjYj6LiorSDz/8oCeffFKdO3dWZGSkJkyYoGeffdZRu2DFmB8AAFyDybj0ceRuICsrSyEhIcrMzFRwcLDdtnvw9Dn1f+1nBfl5afsLA+22XQAAYN/zt0vd7eXMLiRI+n0AAHBuhB87udB/xhPdAQBwboQfu+FWdwAAXAHhx06sPT+OLQMAAFwG4cdOrGN+6PoBAMCpEX7shJ4fAABcA+HHTgzG/AAA4BIIP3ZycbYk0g8AAM6M8GMnFx5vwdMtAABwboQfO7k4z49j6wAAAJUj/NiZicteAAA4NcIPAABwK4QfO+OyFwAAzo3wYycX7/YCAADOjPBjJ9Z5fhxcBwAAqBzhx054qjsAAK6B8AMAANwK4QcAALgVwo+dMN4ZAADXQPgBAABuhfADAADcCuHHzrjZCwAA50b4AQAAboXwAwAA3Arhx04Mnm8BAIBLIPwAAAC3QvgBAABuhfADAADcCuHHzrjVHQAA50b4AQAAboXwAwAA3ArhBwAAuBXCDwAAcCuEHzthikMAAFwD4QcAALgVwg8AAHArhB8AAOBWCD8AAMCtEH4AAIBbIfwAAAC3QvgBAABuhfBjJ7kFZknS0TN5Dq4EAABUhvBjJ4u3n3B0CQAAoAoIP3Zye5dIR5cAAACqgPBjJ77eJf+UTesFOLgSAABQGcIPAABwK4QfAADgVgg/AADArRB+AACAWyH8AAAAt0L4AQAAboXwAwAA3ArhBwAAuBXCDwAAcCuEHwAA4FYIPwAAwK0QfgAAgFsh/AAAALdC+LGzrPwiR5cAAAAqQfixk19TsiVJGbmE
HwAAnBnhx04OpeU4ugQAAFAFhB87GdAhwtElAACAKiD82InJVPJn03oBji0EAABUyinCz8yZM9WsWTP5+fkpLi5O69evr7Dt/PnzZTKZSr38/PxqsNryncjIkyQdOZPr4EoAAEBlHB5+FixYoISEBE2dOlWbNm1Sly5dNHDgQJ06darCdYKDg3Xy5Enr6/DhwzVYcfnGf7LZ0SUAAIAqcHj4mT59usaOHavRo0erffv2mj17tgICAjR37twK1zGZTIqIiLC+wsPDa7Di8k0c1NbRJQAAgCpwaPgpLCxUcnKy4uPjrcs8PDwUHx+vpKSkCtc7d+6coqOjFRUVpTvuuEM7d+6ssG1BQYGysrJKvarDn69vod/+eZt+++dt1bJ9AABgHw4NP2lpaTKbzWV6bsLDw5WSklLuOm3atNHcuXP1zTff6KOPPpLFYlHv3r117NixcttPmzZNISEh1ldUVJTd9wMAALgOh1/2slWvXr00cuRIde3aVddff70WLVqkBg0a6O233y63/aRJk5SZmWl9HT16tIYrBgAAzsTLkR8eFhYmT09PpaamllqempqqiIiqzZvj7e2tbt26af/+/eW+7+vrK19f36uuFQAA1A4O7fnx8fFRbGysEhMTrcssFosSExPVq1evKm3DbDZr+/btatSoUXWVCQAAahGH9vxIUkJCgkaNGqXu3burR48emjFjhnJycjR69GhJ0siRIxUZGalp06ZJkl566SX17NlTLVu2VEZGhl599VUdPnxYY8aMceRuAAAAF+Hw8DNs2DCdPn1aU6ZMUUpKirp27aolS5ZYB0EfOXJEHh4XO6jOnj2rsWPHKiUlRXXr1lVsbKzWrFmj9u3bO2oXAACACzEZhmE4uoialJWVpZCQEGVmZio4ONjR5QAAgCqw5/nb5e72AgAAuBqEHwAA4FYIPwAAwK0QfgAAgFsh/AAAALdC+AEAAG6F8AMAANyKwyc5rGkXpjXKyspycCUAAKCqLpy37TE9oduFn+zsbElSVFSUgysBAAC2ys7OVkhIyFVtw+1meLZYLDpx4oSCgoJkMpnsuu2srCxFRUXp6NGjtXr2aPaz9nGXfWU/axd32U/Jffa1sv00DEPZ2dlq3LhxqcdeXQm36/nx8PBQkyZNqvUzgoODa/Uv5wXsZ+3jLvvKftYu7rKfkvvsa0X7ebU9Phcw4BkAALgVwg8AAHArhB878vX11dSpU+Xr6+voUqoV+1n7uMu+sp+1i7vsp+Q++1pT++l2A54BAIB7o+cHAAC4FcIPAABwK4QfAADgVgg/AADArRB+bDRz5kw1a9ZMfn5+iouL0/r16yttv3DhQrVt21Z+fn7q1KmTvv/++xqq9MpMmzZN1157rYKCgtSwYUMNGTJEe/furXSd+fPny2QylXr5+fnVUMVX5oUXXihTc9u2bStdx9WO5QXNmjUrs68mk0njxo0rt72rHM+VK1dq8ODBaty4sUwmk77++utS7xuGoSlTpqhRo0by9/dXfHy89u3bd9nt2vodr26V7WdRUZGeffZZderUSYGBgWrcuLFGjhypEydOVLrNK/n9rwmXO6YPPvhgmbpvueWWy27XlY6ppHK/ryaTSa+++mqF23TGY1qV80l+fr7GjRun+vXrq06dOrrrrruUmppa6Xav9Lt9KcKPDRYsWKCEhARNnTpVmzZtUpcuXTRw4ECdOnWq3PZr1qzR8OHD9fDDD2vz5s0aMmSIhgwZoh07dtRw5VX3888/a9y4cVq7dq2WLVumoqIiDRgwQDk5OZWuFxwcrJMnT1pfhw8frqGKr1yHDh1K1bxq1aoK27risbxgw4YNpfZz2bJlkqS77767wnVc4Xjm5OSoS5cumjlzZrnv/+tf/9Kbb76p2bNna926dQoMDNTAgQOVn59f4TZt/Y7XhMr2Mzc3V5s2bdLkyZO1adMmLVq0SHv37tXtt99+2e3a8vtfUy53TCXplltuKVX3p59+Wuk2Xe2YSiq1
fydPntTcuXNlMpl01113VbpdZzumVTmfPPnkk/ruu++0cOFC/fzzzzpx4oSGDh1a6Xav5LtdhoEq69GjhzFu3Djrz2az2WjcuLExbdq0ctvfc889xm233VZqWVxcnPGnP/2pWuu0p1OnThmSjJ9//rnCNvPmzTNCQkJqrig7mDp1qtGlS5cqt68Nx/KCCRMmGC1atDAsFku577vi8ZRkfPXVV9afLRaLERERYbz66qvWZRkZGYavr6/x6aefVrgdW7/jNe33+1me9evXG5KMw4cPV9jG1t9/RyhvX0eNGmXccccdNm2nNhzTO+64w+jfv3+lbVzhmP7+fJKRkWF4e3sbCxcutLbZvXu3IclISkoqdxtX+t3+PXp+qqiwsFDJycmKj4+3LvPw8FB8fLySkpLKXScpKalUe0kaOHBghe2dUWZmpiSpXr16lbY7d+6coqOjFRUVpTvuuEM7d+6sifKuyr59+9S4cWPFxMRoxIgROnLkSIVta8OxlEp+jz/66CM99NBDlT7Y1xWP56UOHTqklJSUUscsJCREcXFxFR6zK/mOO6PMzEyZTCaFhoZW2s6W339nsmLFCjVs2FBt2rTRo48+qvT09Arb1oZjmpqaqsWLF+vhhx++bFtnP6a/P58kJyerqKio1PFp27atmjZtWuHxuZLvdnkIP1WUlpYms9ms8PDwUsvDw8OVkpJS7jopKSk2tXc2FotFTzzxhPr06aOOHTtW2K5NmzaaO3euvvnmG3300UeyWCzq3bu3jh07VoPV2iYuLk7z58/XkiVLNGvWLB06dEh9+/ZVdnZ2ue1d/Vhe8PXXXysjI0MPPvhghW1c8Xj+3oXjYssxu5LvuLPJz8/Xs88+q+HDh1f68Etbf/+dxS233KIPPvhAiYmJeuWVV/Tzzz9r0KBBMpvN5bavDcf0/fffV1BQ0GUvBTn7MS3vfJKSkiIfH58yQf1y59ULbaq6Tnnc7qnuqLpx48Zpx44dl71u3KtXL/Xq1cv6c+/evdWuXTu9/fbb+vvf/17dZV6RQYMGWf/euXNnxcXFKTo6Wp9//nmV/g/LVc2ZM0eDBg1S48aNK2zjiscTJYOf77nnHhmGoVmzZlXa1lV//++9917r3zt16qTOnTurRYsWWrFihW666SYHVlZ95s6dqxEjRlz2pgNnP6ZVPZ/UFHp+qigsLEyenp5lRqGnpqYqIiKi3HUiIiJsau9Mxo8fr//973/66aef1KRJE5vW9fb2Vrdu3bR///5qqs7+QkND1bp16wprduVjecHhw4f1448/asyYMTat54rH88JxseWYXcl33FlcCD6HDx/WsmXLKu31Kc/lfv+dVUxMjMLCwiqs25WPqST98ssv2rt3r83fWcm5jmlF55OIiAgVFhYqIyOjVPvLnVcvtKnqOuUh/FSRj4+PYmNjlZiYaF1msViUmJhY6v+SL9WrV69S7SVp2bJlFbZ3BoZhaPz48frqq6+0fPlyNW/e3OZtmM1mbd++XY0aNaqGCqvHuXPndODAgQprdsVj+Xvz5s1Tw4YNddttt9m0nisez+bNmysiIqLUMcvKytK6desqPGZX8h13BheCz759+/Tjjz+qfv36Nm/jcr//zurYsWNKT0+vsG5XPaYXzJkzR7GxserSpYvN6zrDMb3c+SQ2Nlbe3t6ljs/evXt15MiRCo/PlXy3KyoOVfTZZ58Zvr6+xvz5841du3YZjzzyiBEaGmqkpKQYhmEYDzzwgDFx4kRr+9WrVxteXl7Gv//9b2P37t3G1KlTDW9vb2P79u2O2oXLevTRR42QkBBjxYoVxsmTJ62v3Nxca5vf7+eLL75o/PDDD8aBAweM5ORk49577zX8/PyMnTt3OmIXquSpp54yVqxYYRw6dMhYvXq1ER8fb4SFhRmnTp0yDKN2HMtLmc1mo2nTpsazzz5b5j1XPZ7Z2dnG5s2bjc2bNxuSjOnTpxubN2+23uX0z3/+0wgNDTW++eYbY9u2bcYd
d9xhNG/e3MjLy7Nuo3///sZbb71l/fly33FHqGw/CwsLjdtvv91o0qSJsWXLllLf2YKCAus2fr+fl/v9d5TK9jU7O9v461//aiQlJRmHDh0yfvzxR+Oaa64xWrVqZeTn51u34erH9ILMzEwjICDAmDVrVrnbcIVjWpXzyZ///GejadOmxvLly42NGzcavXr1Mnr16lVqO23atDEWLVpk/bkq3+3LIfzY6K233jKaNm1q+Pj4GD169DDWrl1rfe/66683Ro0aVar9559/brRu3drw8fExOnToYCxevLiGK7aNpHJf8+bNs7b5/X4+8cQT1n+T8PBw49ZbbzU2bdpU88XbYNiwYUajRo0MHx8fIzIy0hg2bJixf/9+6/u14Vhe6ocffjAkGXv37i3znqsez59++qnc39UL+2KxWIzJkycb4eHhhq+vr3HTTTeV2f/o6Ghj6tSppZZV9h13hMr289ChQxV+Z3/66SfrNn6/n5f7/XeUyvY1NzfXGDBggNGgQQPD29vbiI6ONsaOHVsmxLj6Mb3g7bffNvz9/Y2MjIxyt+EKx7Qq55O8vDzjscceM+rWrWsEBAQYd955p3Hy5Mky27l0nap8ty/HdH7DAAAAboExPwAAwK0QfgAAgFsh/AAAALdC+AEAAG6F8AMAANwK4QcAALgVwg8AAHArhB8AAOBWCD8Aasxvv/0mk8mkLVu2OLoUm82fP1+hoaF2bwug5hF+ANSYqKgonTx5Uh07dnR0KTYbNmyYfv31V0eXAcAOvBxdAAD3UFhYKB8fH0VERDi6FJsVFRXJ399f/v7+ji4FgB3Q8wPAZjfccIPGjx+v8ePHKyQkRGFhYZo8ebIufVRgs2bN9Pe//10jR45UcHCwHnnkkVKXvSwWi5o0aaJZs2aV2vbmzZvl4eGhw4cPS5KmT5+uTp06KTAwUFFRUXrsscd07ty5UuusXr1aN9xwgwICAlS3bl0NHDhQZ8+e1QcffKD69euroKCgVPshQ4bogQceKHffLtS4YMECXX/99fLz89PHH39c5lLW1q1bdeONNyooKEjBwcGKjY3Vxo0by93m6dOn1b17d915551lagFQ8wg/AK7I+++/Ly8vL61fv15vvPGGpk+frvfee69Um3//+9/q0qWLNm/erMmTJ5d6z8PDQ8OHD9cnn3xSavnHH3+sPn36KDo62truzTff1M6dO/X+++9r+fLleuaZZ6ztt2zZoptuuknt27dXUlKSVq1apcGDB8tsNuvuu++W2WzWt99+a21/6tQpLV68WA899FCl+zdx4kRNmDBBu3fv1sCBA8u8P2LECDVp0kQbNmxQcnKyJk6cKG9v7zLtjh49qr59+6pjx4764osv5OvrW+nnAqgBV/ysegBu6/rrrzfatWtnWCwW67Jnn33WaNeunfXn6OhoY8iQIaXWO3TokCHJ2Lx5s2EYhrF582bDZDIZhw8fNgzDMMxmsxEZGWnMmjWrws9euHChUb9+fevPw4cPN/r06VNh+0cffdQYNGiQ9efXXnvNiImJKVV7eTXOmDGj1PJ58+YZISEh1p+DgoKM+fPnl7uNC2337NljREVFGX/5y18q/DwANY+eHwBXpGfPnjKZTNafe/XqpX379slsNluXde/evdJtdO3aVe3atbP2/vz88886deqU7r77bmubH3/8UTfddJMiIyMVFBSkBx54QOnp6crNzZV0seenImPHjtXSpUt1/PhxSSV3Yj344IOlai/P5WpPSEjQmDFjFB8fr3/+8586cOBAqffz8vLUt29fDR06VG+88cZlPw9AzSH8AKg2gYGBl20zYsQIa/j55JNPdMstt6h+/fqSSsbf/OEPf1Dnzp315ZdfKjk5WTNnzpRUMoBa0mUHIXfr1k1dunTRBx98oOTkZO3cuVMPPvjgVdf+wgsvaOfOnbrtttu0fPlytW/fXl999ZX1fV9fX8XHx+t///ufNXgBcA6EHwBXZN26daV+Xrt2rVq1aiVP
T0+btnPfffdpx44dSk5O1hdffKERI0ZY30tOTpbFYtFrr72mnj17qnXr1jpx4kSp9Tt37qzExMRKP2PMmDGaP3++5s2bp/j4eEVFRdlUY0Vat26tJ598UkuXLtXQoUM1b94863seHh768MMPFRsbqxtvvLFM3QAch/AD4IocOXJECQkJ2rt3rz799FO99dZbmjBhgs3badasmXr37q2HH35YZrNZt99+u/W9li1bqqioSG+99ZYOHjyoDz/8ULNnzy61/qRJk7RhwwY99thj2rZtm/bs2aNZs2YpLS3N2ua+++7TsWPH9O677152oHNV5OXlafz48VqxYoUOHz6s1atXa8OGDWrXrl2pdp6envr444/VpUsX9e/fXykpKVf92QCuHuEHwBUZOXKk8vLy1KNHD40bN04TJkzQI488ckXbGjFihLZu3ao777yz1GWsLl26aPr06XrllVfUsWNHffzxx5o2bVqpdVu3bq2lS5dq69at6tGjh3r16qVvvvlGXl4XpzELCQnRXXfdpTp16mjIkCFXVOOlPD09lZ6erpEjR6p169a65557NGjQIL344otl2np5eenTTz9Vhw4d1L9/f506deqqPx/A1TEZxiUTcwBAFdxwww3q2rWrZsyY4ehSquymm25Shw4d9Oabbzq6FAAOxgzPAGq1s2fPasWKFVqxYoX++9//OrocAE6A8AOgVuvWrZvOnj2rV155RW3atHF0OQCcAJe9AACAW2HAMwAAcCuEHwAA4FYIPwAAwK0QfgAAgFsh/AAAALdC+AEAAG6F8AMAANwK4QcAALiV/w+aIdaL3oII1AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot attack accuracy as a function of per-example privacy risk.\n",
    "# X_axi / Y_axi are assumed to be built in earlier cells — TODO confirm on re-run.\n",
    "df = pd.DataFrame({'xvalues': X_axi, 'yvalues': Y_axi})\n",
    "\n",
    "plt.plot('xvalues', 'yvalues', data=df)\n",
    "plt.xlabel(\"privacy risk\")\n",
    "plt.ylabel(\"attack accuracy\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ee4f5ad-e959-4f63-bde3-78c7e25fc0e4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d74a89-2c75-453a-a3b2-3fd906337030",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "025ed341-77fd-4f0e-b2cf-1ac59f0e74af",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8e1a4ee1-0c47-411a-b98b-d7713d949e09",
   "metadata": {},
   "source": [
    "### 绘制损失分布差异"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "ecd477a2-9df6-4a4a-ba78-716c2d669d45",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "2abddba4-ae48-405a-beb5-c45daf231d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index into the privacy-risk ranking (pri_risk_rank) selecting which\n",
    "# example's loss distribution is plotted below.\n",
    "idx = 31000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "0421a068-0678-435b-86b7-9f37862cc816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Requires loss_data_all, pri_risk_rank and train_keep from earlier cells.\n",
    "# For each shadow model (column i), split the per-example losses into\n",
    "# member rows (train_keep[:, i] True) and non-member rows (mask negated).\n",
    "# Assumes train_keep is a boolean mask aligned with loss_data_all — TODO confirm.\n",
    "n_models = loss_data_all.shape[1]\n",
    "dat_in = np.array([loss_data_all[train_keep[:, i], i] for i in range(n_models)])\n",
    "dat_out = np.array([loss_data_all[~train_keep[:, i], i] for i in range(n_models)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "37c2030e-d91f-4fdf-9a5a-e246ca594916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Member / non-member losses for the example ranked idx-th by privacy risk.\n",
    "mem1, non_mem1 = dat_in[pri_risk_rank[idx]], dat_out[pri_risk_rank[idx]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "5612c786-26df-4ff0-9c42-2af62fa46b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn both 1-D loss vectors into column vectors and place them side by\n",
    "# side so seaborn can plot each column as its own distribution.\n",
    "mem1 = mem1.reshape(-1, 1)\n",
    "non_mem1 = non_mem1.reshape(-1, 1)\n",
    "\n",
    "arr = np.concatenate((mem1, non_mem1), axis=1)\n",
    "\n",
    "df = pd.DataFrame(arr, columns=['loss','out'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "5a3fd823-c33a-4c78-8829-e7bb6410bc02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAG1CAYAAADjkR6kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmCUlEQVR4nO3df3DU9Z3H8dfm52YhCZDw24REBtINipEfYZRqoXJQLFaunV6dEYv0ij1KRUxFjTUi1JqWq5jRIhbuKnbUw2ur1HGQHqZ4tICCeFSRFeEMhuFXukLYkDUh2f3eHzZ7jSQh2ezu9/sJz8fMTtnv/vi+v/tt4Ol+v7txWZZlCQAAwEBJdg8AAAAQLUIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLFS7B4g3sLhsI4fP67MzEy5XC67xwEAAN1gWZYaGho0YsQIJSV1/r5Lnw+Z48ePKy8vz+4xAABAFI4eParLLrus09v7fMhkZmZK+uyFyMrKsnkaAADQHYFAQHl5eZF/xzvT50Om7XBSVlYWIQMAgGEudloIJ/sCAABjETIAAMBYff7QEgAAdguFQmppabF7DEdJTU1VcnJyr5+HkAEAIE4sy9LJkydVX19v9yiONGDAAA0bNqxXX49CyAAAECdtETNkyBB5PB6+z+xvLMtSMBhUXV2dJGn48OFRPxchAwBAHIRCoUjE5OTk2D2O42RkZEiS6urqNGTIkKgPM3GyLwAAcdB2TozH47F5Eudqe216c/4QIQMAQBxxOKlzsXhtCBkAAGAszpEBACDBamtr5ff7E7a+3Nxc5efnJ2x9iUTIAACQQLW1tSoq8qqpKZiwdbrdHh086Ot2zEybNk0lJSWqqqqK72AxQMgAAJBAfr9fTU1Beb3PyePxxn19waBPPt88+f3+bofMSy+9pNTU1DhPFhuEDAAANvB4vMrMnGD3GB0aNGiQ3SN0GyEDfE6ij13bpS8fMwfQO39/aKmgoEB33HGHDh8+rN/85jcaOHCgHnzwQd1xxx12jymJkAHasePYtV16eswcwKXrscce049//GM98MAD+u1vf6tFixbpS1/6koqKiuwejZAB/l6ij13bJZpj5gAuXTfeeKO+//3vS5Luu+8+Pf7449q2bRshAziVk49dA0CijR8/PvJnl8ulYcOGRX5Pkt34QjwAANClz3+CyeVyKRwO2zRNe4QMAAAwFoeWAACwQTDo61PrsQshAwBAAuXm5srt9sjnm5ewdbrdHuXm5iZsfYlEyAAAkED5+fk6eNDn6N+19MYbb0T+fOTIkQtu37dvX++HihFCBgCABMvPz+erD2KEk30BAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLEIGQAAYCxCBgAAGIvvkQEAIMFqa2sd/YV4JiFkAABIoNraWnmLihRsakrYOj1ut3wHDyYsZh5++GFt2rQpId8ATMgAAJBAfr9fwaYmPef1yuvxxH19vmBQ83yf/UqEvviuDCEDAIANvB6PJmRm2j1Gh5qbm7Vs2TJt3LhRgUBAkyZN0uOPP67Jkydrw4YNWrp0qerr6yP337Rpk/7xH/9RlmVpw4YNWrFihSTJ5XJJkp555hndfvvtcZmVkAEAAO3ce++9+t3vfqdnn31Wo0aN0qpVqzRr1iwdPnz4oo/91re+pf3792vLli16/fXXJUnZ2dlxm5WQAQAAEY2NjVq7dq02bNig2bNnS5LWr1+vrVu36t///d81ePDgLh+fkZGh/v37KyUlRcOGDYv7vHz8GgAARPzv//6vWlpaNHXq1Miy1NRUlZaWyufz2ThZxwgZAADQbUlJSbIsq92ylpYWm6YhZAAAwN8ZPXq00tLStGPHjsiylpYW7dmzR8XFxRo8eLAaGhrU2NgYuf3zH7NOS0tTKBRKyLycIwMAACL69eunRYsWadmyZRo0aJDy8/O1atUqBYNB/fM//7Ms
y5LH49EDDzygJUuW6K233tKGDRvaPUdBQYFqamq0b98+XXbZZcrMzFR6enpc5iVkAACwgS8YdOx6fvrTnyocDuu2225TQ0ODJk2apD/84Q8aOHCgJOm5557TsmXLtH79et1www16+OGHdccdd0Qe/41vfEMvvfSSpk+frvr6ej5+DQBAX5GbmyuP2615CTxx1uN2Kzc3t9v3d7vdeuKJJ/TEE090ePvcuXM1d+7cdssWLlwY+XN6erp++9vfRjVrT9kaMtu3b9e//uu/au/evTpx4oRefvnldi+MZVlavny51q9fr/r6ek2dOlVr167VmDFj7BsaAIBeyM/Pl+/gQX7XUozYGjKNjY266qqr9J3vfEdf//rXL7h91apVeuKJJ/Tss8+qsLBQFRUVmjVrlg4cOCC3223DxAAA9F5+fn6fDYtEszVkZs+eHfmync+zLEtVVVV68MEHdfPNN0uSfv3rX2vo0KHatGmTbrnllkSOCgAAHMixH7+uqanRyZMnNWPGjMiy7OxsTZkyRbt27er0cc3NzQoEAu0uAACgb3JsyJw8eVKSNHTo0HbLhw4dGrmtI5WVlcrOzo5c8vLy4jonAACwj2NDJlrl5eU6e/Zs5HL06FG7RwIAXMI+/y24+H+xeG0cGzJtv2jq1KlT7ZafOnWqy19ClZ6erqysrHYXAAASLTU1VZIUTND3xZio7bVpe62i4djvkSksLNSwYcNUXV2tkpISSVIgENBbb72lRYsW2TscAAAXkZycrAEDBqiurk6S5PF45HK5bJ7KGSzLUjAYVF1dnQYMGKDk5OSon8vWkDl37pwOHz4cud72dcZtX4m8dOlSPfLIIxozZkzk49cjRoy44Et4AABworYjCG0xg/YGDBjQ5VGW7rA1ZN5++21Nnz49cr2srEySNH/+fG3YsEH33nuvGhsbdccdd6i+vl5f/OIXtWXLFr5DBgBgBJfLpeHDh2vIkCG2/oZoJ0pNTe3VOzFtbA2ZadOmdXmij8vl0sqVK7Vy5coETgUAQGwlJyfH5B9tXMixJ/sCAABcDCEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWCl2D2Cy2tpa+f1+u8dIiNzcXOXn59s9BgAA7RAyUaqtrVVRkVdNTUG7R0kIt9ujgwd9xAwAwFEImSj5/X41NQXl9T4nj8dr9zhxFQz65PPNk9/vJ2QAAI5CyPSSx+NVZuYEu8cAAOCSxMm+AADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWI4OmVAopIqKChUWFiojI0OjR4/Wj3/8Y1mWZfdoAADAAVLsHqArP/vZz7R27Vo9++yzGjdunN5++20tWLBA2dnZWrJkid3jAQAAmzk6ZHbu3Kmbb75ZX/3qVyVJBQUF+o//+A/t3r3b5skAAIATODpkrr32Wq1bt04ffvihxo4dq7/85S/685//rNWrV3f6mObmZjU3N0euBwKBRIx6yaqtrZXf77d7jJjx+XySpGDQ1637p6bmyu3Oj+dIcLi+9jMQjdzcXOXn83MAezg6ZO6//34FAgF94QtfUHJyskKhkH7yk5/o1ltv7fQxlZWVWrFiRQKnvHTV1tbKW1SkYFOT3aPEnM83r1v3S0lya1LpQWLmEtWXfwZ6wuN2y3fwIDEDWzg6ZP7zP/9Tzz//vF544QWNGzdO+/bt09KlSzVixAjNnz+/w8eU
l5errKwscj0QCCgvLy9RI19S/H6/gk1Nes7rldfjsXucmGgMBnXA51O/DK+SkrvepppwUA8GfWpp8RMyl6i++DPQU75gUPN8Pvn9fkIGtnB0yCxbtkz333+/brnlFknSlVdeqY8//liVlZWdhkx6errS09MTOeYlz+vxaEJmpt1jxESDpFZJmckeJadcZJtaEzERTNCXfgYA0zj649fBYFBJSe1HTE5OVjgctmkiAADgJI5+R+amm27ST37yE+Xn52vcuHH6n//5H61evVrf+c537B4NAAA4gKND5sknn1RFRYW+//3vq66uTiNGjND3vvc9PfTQQ3aPBgAAHMDRIZOZmamqqipVVVXZPQoAAHAgR58jAwAA0BVCBgAAGIuQAQAAxiJkAACAsQgZAABgLEIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLEIGQAAYCxCBgAAGIuQAQAAxiJkAACAsQgZAABgLEIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLFS7B4AMF0w6LN7hB5rm9nn6/3szc3NSk9P7/XzxNOJEydUX18f8+etqamRJDUGg2qI+bNHJzU1VW632+4xLim1tbXy+/12j2Gb3Nxc5efn27Z+QgaIkj98XkmSfL55do8StXnzej97kqRw70cx2gGfT612D/E3SUlJmlJaSswkSG1trbxFRQo2Ndk9im08brd8Bw/aFjOEDBClBqtVYUkr0ws0OjXH7nF6JBwKqvFTn4q9XvXzeKJ+ns2ffKKKI0e0vqBAE3Kc+Ro0BoM64PPJ7S5Qsisjps+9o7VeT7eckNtdoMwU+7c/FA4qGPSppaWFkEkQv9+vYFOTnvN65e3Fz5KpfMGg5vl88vv9hAxgqsKkDHlTMu0eo0dCkhoklXg8ysyMfnZfMChJKsrI0IRePE88NUhqlZSZkqPkGO+njyWp5YSSXRkxf+6oOOVtoUuQ1+Nx7M9AX8fJvgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwVlQhc/nll+uTTz65YHl9fb0uv/zyXg8FAADQHVGFzJEjRxQKhS5Y3tzcrGPHjvV6KAAAgO5I6cmdX3nllcif//CHPyg7OztyPRQKqbq6WgUFBTEbDgAAoCs9Cpm5c+dKklwul+bPn9/uttTUVBUUFOixxx6L2XAAAABd6VHIhMNhSVJhYaH27Nmj3NzcuAwFAADQHT0KmTY1NTWxngMAAKDHogoZSaqurlZ1dbXq6uoi79S0+dWvftXrwQAAAC4mqk8trVixQjNnzlR1dbX8fr/OnDnT7hJLx44d07x585STk6OMjAxdeeWVevvtt2O6DgAAYKao3pF5+umntWHDBt12222xnqedM2fOaOrUqZo+fbpee+01DR48WIcOHdLAgQPjul4AAGCGqELm/Pnzuvbaa2M9ywV+9rOfKS8vT88880xkWWFhYdzXCwAAzBDVoaXvfve7euGFF2I9ywVeeeUVTZo0Sd/85jc1ZMgQXX311Vq/fn2Xj2lublYgEGh3AQAAfVNU78g0NTVp3bp1ev311zV+/Hilpqa2u3316tUxGe6jjz7S2rVrVVZWpgceeEB79uzRkiVLlJaWdsH32LSprKzUihUrYrJ+AADgbFGFzLvvvquSkhJJ0v79+9vd5nK5ej1Um3A4rEmTJunRRx+VJF199dXav3+/nn766U5Dpry8XGVlZZHrgUBAeXl5MZsJAAA4R1Qhs23btljP0aHhw4eruLi43TKv16vf/e53nT4mPT1d6enp8R4NAAA4QFTnyCTK1KlTdfDgwXbLPvzwQ40aNcqmiQAAgJNE9Y7M9OnTuzyE9Mc//jHqgf7e3XffrWuvvVaPPvqo/umf
/km7d+/WunXrtG7dupg8PwAAMFtUIdN2fkyblpYW7du3T/v37+/03JVoTJ48WS+//LLKy8u1cuVKFRYWqqqqSrfeemvM1gEAAMwVVcg8/vjjHS5/+OGHde7cuV4N9Hlz5szRnDlzYvqcAACgb4jpOTLz5s3j9ywBAICEiWnI7Nq1S263O5ZPCQAA0KmoDi19/etfb3fdsiydOHFCb7/9tioqKmIyGAAAwMVEFTLZ2dntriclJamoqEgrV67UzJkzYzIYAADAxUQVMn//SxwBAADsElXItNm7d698Pp8kady4cbr66qtjMhQAAEB3RBUydXV1uuWWW/TGG29owIABkqT6+npNnz5dGzdu1ODBg2M5IwAAQIei+tTSnXfeqYaGBr3//vs6ffq0Tp8+rf379ysQCGjJkiWxnhEAAKBDUb0js2XLFr3++uvyer2RZcXFxVqzZg0n+wIAgISJ6h2ZcDis1NTUC5anpqYqHA73eigAAIDuiCpkvvzlL+uuu+7S8ePHI8uOHTumu+++WzfccEPMhgMAAOhKVCHzi1/8QoFAQAUFBRo9erRGjx6twsJCBQIBPfnkk7GeEQAAoENRnSOTl5end955R6+//ro++OADSZLX69WMGTNiOhwAAEBXevSOzB//+EcVFxcrEAjI5XLpH/7hH3TnnXfqzjvv1OTJkzVu3Dj96U9/itesAAAA7fQoZKqqqrRw4UJlZWVdcFt2dra+973vafXq1TEbDgAAoCs9Cpm//OUv+spXvtLp7TNnztTevXt7PRQAAEB39ChkTp061eHHrtukpKTor3/9a6+HAgAA6I4ehczIkSO1f//+Tm9/9913NXz48F4PBQAA0B09Cpkbb7xRFRUVampquuC2Tz/9VMuXL9ecOXNiNhwAAEBXevTx6wcffFAvvfSSxo4dqx/84AcqKiqSJH3wwQdas2aNQqGQfvSjH8VlUAAAgM/rUcgMHTpUO3fu1KJFi1ReXi7LsiRJLpdLs2bN0po1azR06NC4DAoAAPB5Pf5CvFGjRmnz5s06c+aMDh8+LMuyNGbMGA0cODAe8wEAAHQqqm/2laSBAwdq8uTJsZwFAACgR6L6XUsAAABOEPU7MvhMMOize4S4a9vGzZs3y+f7/+2tqamRJDUGg2qwZbLYawwG7R4BQDfU1tbK7/fbPUbk78R4/j2Ympoqt9sdp2c3HyETpRMnTihJks83z+5REqaioqLD5Qd8PrUmeJZ4C1uWku0eAkCHamtrVVTkVVOTc/7DI55/DyYlJWlKaSkx0wlCJkr19fUKS1qZXqDRqTl2jxNXLa2fqKnpiNLTRikl2RNZvqO1Xk+3nJDbXaDMlL7xGrS0nlZTU03kE3kAnMfv96upKSiv9zl5PF5bZwkGffL55qlfhleZf/f3Y6yEwkEFgz61tLQQMp0gZHqpMClD3pRMu8eIq/PhoIKSPCmDlJaWHVn+sSS1nFCyK0PJfeQ1CIWd8194ALrm8XiVmTnB7jEkSUnJnvj8PdjX3u6OA072BQAAxiJkAACAsQgZAABgLEIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLEIGQAAYCxCBgAAGIuQAQAAxiJkAACAsQgZAABgLEIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLEIGQAAYCxCBgAAGIuQAQAAxjIqZH7605/K5XJp6dKldo8CAAAcwJiQ2bNnj375y19q/Pjxdo8CAAAcwoiQOXfunG699VatX79eAwcOtHscAADgECl2D9Adixcv1le/+lXNmDFDjzzySJf3bW5uVnNzc+R6IBCI93iAsRqDwV49/tNPP438b0NDQyxGirnebiMAZ3N8yGzcuFHvvPOO9uzZ0637V1ZWasWKFXGeCjBb2DovSfL5fL16npq2/z1yRKlHjvRuqDgLW5aS7R4CQMw5OmSOHj2qu+66S1u3bpXb7e7WY8rLy1VW
Vha5HggElJeXF68RASNZVqskye0uUmpK/6ifJ6PlE6n5iNzuAmWm5MRqvJhqaT2tpqYaWZZl9ygA4sDRIbN3717V1dVpwoQJkWWhUEjbt2/XL37xCzU3Nys5uf1/Y6Wnpys9PT3RowJGSkryKDklM+rHu0KfHbZJdmX06nniKRTm0BLQlzk6ZG644Qa999577ZYtWLBAX/jCF3TfffddEDEAAODS4uiQyczM1BVXXNFuWb9+/ZSTk3PBcgAAcOkx4uPXAAAAHXH0OzIdeeONN+weAQAAOATvyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYzk6ZCorKzV58mRlZmZqyJAhmjt3rg4ePGj3WAAAwCEcHTL//d//rcWLF+vNN9/U1q1b1dLSopkzZ6qxsdHu0QAAgAOk2D1AV7Zs2dLu+oYNGzRkyBDt3btX119/vU1TAQAAp3B0yHze2bNnJUmDBg3q9D7Nzc1qbm6OXA8EAnGfCwCcpDEYTPi6Nm/eLJ/Pl5B11tTUSJKCwcSsrytOmOFSZ0zIhMNhLV26VFOnTtUVV1zR6f0qKyu1YsWKBE4GAM4Qts5LUsKCQpLe1GfnKFRUVCRsnW18vnkJX2dnrL+99kg8Y0Jm8eLF2r9/v/785z93eb/y8nKVlZVFrgcCAeXl5cV7PACwnWW1SpLc7iKlpvRPyDpDLZ8o3HxED6eN0Ojk7ISsszV0Vs3njyd0Ozuzo/UTPdV0RFa41dY5LmVGhMwPfvADvfrqq9q+fbsuu+yyLu+bnp6u9PT0BE0GAM6TlORRckpmQtblCn12aGl0crbGpQ9NyDrPn5eC54/Lk+RRWoK2szM1ocQdxkPHHB0ylmXpzjvv1Msvv6w33nhDhYWFdo8EAAAcxNEhs3jxYr3wwgv6/e9/r8zMTJ08eVKSlJ2drYyMDJunAwAAdnP098isXbtWZ8+e1bRp0zR8+PDI5cUXX7R7NAAA4ACOfkfGsiy7RwAAAA7m6HdkAAAAukLIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMRcgAAABjETIAAMBYhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDAAAMBYhAwAAjEXIAAAAYxEyAADAWIQMAAAwFiEDAACMZUTIrFmzRgUFBXK73ZoyZYp2795t90gAAMABHB8yL774osrKyrR8+XK98847uuqqqzRr1izV1dXZPRoAALCZ40Nm9erVWrhwoRYsWKDi4mI9/fTT8ng8+tWvfmX3aAAAwGYpdg/QlfPnz2vv3r0qLy+PLEtKStKMGTO0a9euDh/T3Nys5ubmyPWzZ89KkgKBQExnCwaDkqT3W8/oUysU0+d2mtZQQM2S
0lv8Sgk3RpZ/1PrZa9uXXoPOtrUjJm9/T7azKya8BrHa1o44bfvjua2dseM1sGM7OxPv7Q9ZTWqS1NzQIE/I/v+Pfd7Bv/1beO7cuZj/O9v2fJZldX1Hy8GOHTtmSbJ27tzZbvmyZcus0tLSDh+zfPlySxIXLly4cOHCpQ9cjh492mUrOPodmWiUl5errKwscj0cDuv06dPKycmRy+WycTIzBAIB5eXl6ejRo8rKyrJ7HMQA+7TvYZ/2LezPjlmWpYaGBo0YMaLL+zk6ZHJzc5WcnKxTp061W37q1CkNGzasw8ekp6crPT293bIBAwbEa8Q+Kysrix+oPoZ92vewT/sW9ueFsrOzL3ofR5/sm5aWpokTJ6q6ujqyLBwOq7q6Wtdcc42NkwEAACdw9DsyklRWVqb58+dr0qRJKi0tVVVVlRobG7VgwQK7RwMAADZzfMh861vf0l//+lc99NBDOnnypEpKSrRlyxYNHTrU7tH6pPT0dC1fvvyCw3MwF/u072Gf9i3sz95xWdbFPtcEAADgTI4+RwYAAKArhAwAADAWIQMAAIxFyAAAAGMRMgAAwFiEDHqlpqZG06dPV3Fxsa688ko1Ntr7C9zQe8FgUKNGjdI999xj9yjopaNHj2ratGkqLi7W+PHj9Zvf/MbukdBDr776qoqKijRmzBj927/9m93jOBIfv0avfOlLX9Ijjzyi6667TqdPn1ZWVpZSUhz/9UTowo9+9CMdPnxYeXl5+vnPf273OOiFEydO6NSpUyopKdHJkyc1ceJEffjhh+rXr5/do6EbWltbVVxcrG3btik7O1sTJ07Uzp07lZOTY/dojsI7Moja+++/r9TUVF133XWSpEGDBhExhjt06JA++OADzZ492+5REAPDhw9XSUmJJGnYsGHKzc3V6dOn7R0K3bZ7926NGzdOI0eOVP/+/TV79mz913/9l91jOQ4h04dt375dN910k0aMGCGXy6VNmzZdcJ81a9aooKBAbrdbU6ZM0e7du7v9/IcOHVL//v110003acKECXr00UdjOD0+L977U5LuueceVVZWxmhiXEwi9mmbvXv3KhQKKS8vr5dTo7t6u3+PHz+ukSNHRq6PHDlSx44dS8ToRiFk+rDGxkZdddVVWrNmTYe3v/jiiyorK9Py5cv1zjvv6KqrrtKsWbNUV1cXuU9JSYmuuOKKCy7Hjx9Xa2ur/vSnP+mpp57Srl27tHXrVm3dujVRm3fJiff+/P3vf6+xY8dq7NixidqkS16892mb06dP69vf/rbWrVsX923C/4vF/kU3WLgkSLJefvnldstKS0utxYsXR66HQiFrxIgRVmVlZbeec+fOndbMmTMj11etWmWtWrUqJvOia/HYn/fff7912WWXWaNGjbJycnKsrKwsa8WKFbEcG12Ixz61LMtqamqyrrvuOuvXv/51rEZFFKLZvzt27LDmzp0buf2uu+6ynn/++YTMaxLekblEnT9/Xnv37tWMGTMiy5KSkjRjxgzt2rWrW88xefJk1dXV6cyZMwqHw9q+fbu8Xm+8RkYXYrE/KysrdfToUR05ckQ///nPtXDhQj300EPxGhkXEYt9almWbr/9dn35y1/WbbfdFq9REYXu7N/S0lLt379fx44d07lz5/Taa69p1qxZdo3sWITMJcrv9ysUCl3wW8SHDh2qkydPdus5UlJS9Oijj+r666/X+PHjNWbMGM2ZMyce4+IiYrE/4Syx2Kc7duzQiy++qE2bNqmkpEQlJSV677334jEueqg7+zclJUWPPfaYpk+frpKSEv3whz/kE0sd4CMm6JXZs2fzCZc+6Pbbb7d7BMTAF7/4RYXDYbvHQC987Wtf09e+9jW7x3A03pG5ROXm5io5OVmnTp1qt/zUqVMaNmyYTVMhWuzPvod92rexf2OHkLlEpaWlaeLEiaquro4sC4fDqq6u1jXXXGPjZIgG+7PvYZ/2bezf2OHQUh927tw5HT58OHK9pqZG+/bt
06BBg5Sfn6+ysjLNnz9fkyZNUmlpqaqqqtTY2KgFCxbYODU6w/7se9infRv7N0Hs/tgU4mfbtm2WpAsu8+fPj9znySeftPLz8620tDSrtLTUevPNN+0bGF1if/Y97NO+jf2bGPyuJQAAYCzOkQEAAMYiZAAAgLEIGQAAYCxCBgAAGIuQAQAAxiJkAACAsQgZAABgLEIGAAAYi5AB4CjTpk3T0qVL7R4DgCEIGQAAYCxCBgAAGIuQAeBYZ86c0be//W0NHDhQHo9Hs2fP1qFDhyK3f/zxx7rppps0cOBA9evXT+PGjdPmzZsjj7311ls1ePBgZWRkaMyYMXrmmWfs2hQAcZJi9wAA0Jnbb79dhw4d0iuvvKKsrCzdd999uvHGG3XgwAGlpqZq8eLFOn/+vLZv365+/frpwIED6t+/vySpoqJCBw4c0Guvvabc3FwdPnxYn376qc1bBCDWCBkAjtQWMDt27NC1114rSXr++eeVl5enTZs26Zvf/KZqa2v1jW98Q1deeaUk6fLLL488vra2VldffbUmTZokSSooKEj4NgCIPw4tAXAkn8+nlJQUTZkyJbIsJydHRUVF8vl8kqQlS5bokUce0dSpU7V8+XK9++67kfsuWrRIGzduVElJie69917t3Lkz4dsAIP4IGQDG+u53v6uPPvpIt912m9577z1NmjRJTz75pCRp9uzZ+vjjj3X33Xfr+PHjuuGGG3TPPffYPDGAWCNkADiS1+tVa2ur3nrrrciyTz75RAcPHlRxcXFkWV5env7lX/5FL730kn74wx9q/fr1kdsGDx6s+fPn67nnnlNVVZXWrVuX0G0AEH+cIwPAkcaMGaObb75ZCxcu1C9/+UtlZmbq/vvv18iRI3XzzTdLkpYuXarZs2dr7NixOnPmjLZt2yav1ytJeuihhzRx4kSNGzdOzc3NevXVVyO3Aeg7eEcGgGM988wzmjhxoubMmaNrrrlGlmVp8+bNSk1NlSSFQiEtXrxYXq9XX/nKVzR27Fg99dRTkqS0tDSVl5dr/Pjxuv7665WcnKyNGzfauTkA4sBlWZZl9xAAAADR4B0ZAABgLEIGAAAYi5ABAADGImQAAICxCBkAAGAsQgYAABiLkAEAAMYiZAAAgLEIGQAAYCxCBgAAGIuQAQAAxiJkAACAsf4PQCpQwnVN9OIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Overlay member (in) vs non-member (out) loss histograms on a log x-scale\n",
    "# to visualise how separable the two distributions are for this example.\n",
    "sns.histplot(data=df, x=\"loss\", color=\"blue\", label=\"in\", log_scale=True)\n",
    "sns.histplot(data=df, x=\"out\", color=\"red\", label=\"out\", log_scale=True)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05bd3cc7-e9f6-438b-9e72-7a67d682b590",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "3501d38f-7388-44cb-8b49-3c03c82ab8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    \"\"\"Display an image array with matplotlib and show the figure.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    img : array-like\n",
    "        Image data accepted by plt.imshow.\n",
    "    \"\"\"\n",
    "    # Removed the no-op aliases (img = img; npimg = img) from the original.\n",
    "    plt.imshow(img)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "94652e56-d847-4409-89fa-99e1b48e52d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvuUlEQVR4nO3de3Cc9X3v8c/eV9eVdZcsycgXbIwvaRwwOiSUYNeXnjIQPB1IcqYmZWCgMlNw0yTuJBBoe0TJTEKSccwfpbiZE0NCTwyFaaBgsDhJbFI7OI4TomBHYBlb8lX322r3OX+kqBUY+H1tyT9JvF8zO2Npv/7q9zzP7n71aHc/GwqCIBAAABdY2PcCAAAfTgwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXUd8LeKdsNqujR4+qoKBAoVDI93IAAEZBEKinp0fV1dUKh9/7PGfSDaCjR4+qtrbW9zIAAOepra1NNTU173n9hA2gzZs362tf+5ra29u1dOlSffvb39bll1/+gf+voKBAkrS4rlaR95mc/92MGQnndd100xrnWklq7+x3rj3dbztjy2YyzrUZQ60kxaIR59rKsjJT7+GRrKm+te2Ie3EoZuqdSOQ413Z1njH17u0ddK4Nq8DUe2697Zes3Nx859pMJtfU+5e/es25Nq/E1vvKVZ9wrj145JCpd39vt3PtJbPnmXq3tLxpqo+G3e+fJfnut1lJioy4P64UlJaaep860+VcO9PQe2BgQJ/fePvo4/l7mZAB9P3vf18bN27Uww8/rOXLl+uhhx7S6tWr1dLSovLy8vf9v2//2S0SDjsPoGjE/cE2J+k+rCQpmRxxrk1kpuYAShr3Sdg4gOLxuHuxeQC5947Hbb1jMfd9HpZhGyUlEklTfdJQn8nYHuBiMcs+tN1WcnLdB1YiaVv3yMjwhKxDkuLG4xMzDKBk0tbbMoBycmzbmRww7ENjb0kf+DTKhLwI4etf/7puvfVWfe5zn9PChQv18MMPKzc3V//0T/80ET8OADAFjfsAGh4e1t69e7Vy5cr/+iHhsFauXKldu3a9q35oaEjd3d1jLgCA6W/cB9DJkyeVyWRUUVEx5vsVFRVqb29/V31TU5NSqdTohRcgAMCHg/f3AW3atEldXV2jl7a2Nt9LAgBcAOP+IoTS0lJFIhF1dHSM+X5HR4cqKyvfVZ9IJJRI2J7YBABMfeN+BhSPx7Vs2TLt2LFj9HvZbFY7duxQQ0PDeP84AMAUNSEvw964caPWr1+vj33sY7r88sv10EMPqa+vT5/73Ocm4scBAKagCRlAN954o06cOKF77rlH7e3t+shHPqJnn332XS9MAAB8eE1YEsKGDRu0YcOGc/7/rSczCoUCp9og6VYnSYcOH7YtJOL+5rjuE2lT66qzPCf2Xkxv5pSUTLrXV5WWmHqXV73/m4nfafHiuc61pzs7Tb17utzfFDs4UGXqnR5xfxPyqRMDpt6ZIfc3AErSb1rd0woGh21viC4vd78dnuo6Yep98Bd7nWuPnHzL1Du/0P2+GQy7v+NfknqO/MZUX17inoSRTBabeh9pdU8SSQYzTb2r8t2ffy+JuN+u+iNu9wfvr4IDAHw4MYAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeTFgUz/nKJqSQ43gczOY59z3Y6h7dIkknTvzWufadH0HxQerqZjnXlpaVmnovXHCxc23nGdun0Pb09Zrq82e4x5QosP1ONNDvHoGTzdqOfSLhHmc0MtJp6t1zxlbf8pufOdcGEds+LCq73Lk2GbVF2kTSGefaYvdkHUlSXsL9ft91vMfUO5bjHsMkSZHcQefavPKYqXf/4X7n2t2vHjD1rqmqca7trYo41w4Ouu0PzoAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAA
AF4wgAAAXkzaLLggk5GygVPt6TPu2WS/+M2bpnWkCt13UTiZNPU+cvy4c+1Axj1TS5I+8rGPOde2tr1h6j1oyPeSpIEh91ytpR+9zNT75JnTzrUzig2ZdJJOnD7lXDscuN1W35bJ2Orr69wzu5IFtqyxvJyQc21ne6epd3972rn2ik9eber9w6d/5Fwb2G6y6u11v29K0qGDQ8617cfaTL3rL1roXBuWLfPulZ//3Ln2oovOONcODw871XEGBADwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADwYtJG8UQzcYVCbvOxsDDPuW9ukXvsiCQFkaxzbWVNnam3a1yFJJVXVJh6h+IJ59pfvHbQ1HvR4o+a6t889pZz7f4DtqikM2fc43IuTs4y9e443elcm4jmmHofP91nqi/ML3OujSVsv1eG5R5RVJyaberde9o9viU84n5fk6R589xv4/HIoKn3ySO2OCMp7lxZWZpv6pzMuO+XmSW2Y5+zNNe5tqevw7k2nHWL3+IMCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAODFpM2CC0UyCocCp9rKyhLnvlnZMqG6urqdawf7Mqbevb29zrU5CVt+1KmTXc61x46fNvVeGnfPvZKk4tJC59q2I7819Q5C7jlZI6o29Y7H3PPAynLdM7UkqT3mdtt+23DWvX5kwPZ75cnOk861uTm2zLueYffb+FsnfmfqvWCOewbkRbPcHyMkKT93nqk+EXO/TwwPpUy9f/TsUefakhz3fElJuuqKcufan+0dcq4dHEw71XEGBADwYtwH0Fe/+lWFQqExlwULFoz3jwEATHET8ie4Sy+9VC+88MJ//ZDopP1LHwDAkwmZDNFoVJWVlRPRGgAwTUzIc0Cvv/66qqurNXv2bH32s5/V4cOH37N2aGhI3d3dYy4AgOlv3AfQ8uXLtXXrVj377LPasmWLWltb9YlPfEI9PT1nrW9qalIqlRq91NbWjveSAACT0LgPoLVr1+pP//RPtWTJEq1evVr/9m//ps7OTv3gBz84a/2mTZvU1dU1emlraxvvJQEAJqEJf3VAUVGRLr74Yh08ePCs1ycSCSUS7p/tDgCYHib8fUC9vb06dOiQqqqqJvpHAQCmkHEfQJ///OfV3NysN954Qz/96U/1qU99SpFIRJ/+9KfH+0cBAKawcf8T3JEjR/TpT39ap06dUllZmT7+8Y9r9+7dKisrszUKMgrkFj/y1lvuzxsNp21RPPl57hE4eTm2PyXmGdJbsoYolt/Xu9dGIrabQcjxuLxtaKDfuXbEeHxy8pPOtV1nbJFDQ4YYpoEBWwxTeZ97rIkkxWLuB7Qvz3DwJQ0bbraDsQFT77wZ7veftG0XquOY+1oSxke6krKIrb7IfSd2nnG/P0hSZ5d7fU3pTFPvtrfO/tTI2Qym3aOPBkfczm3GfQA9/vjj490SADANkQUHAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADwggEEAPBiwj+O4VzFYnGFw27zMRJxz22KBrZNjifizrV5ee5ZSZKUn++ek2WplaQewyfL1lRVm3qf6Dhhqu/v7XOuzQyPmHr3dfc61x7LHjH1zvS4Z3CdNuZ75aWHTfWFxe7HPx5LmXpXlBY614aTOabekZEC59oe44ch95xxr/1ti+1zxkIRW+ZdRYn7PozJdnxmFM5zro0maky9D7zmngU3MOL+ODs07FbLGRAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwItJG8UTjUYUDrvFOaRS7tEWhYW2uJz+fvdIjjNnDNkgknJy3GNN2tvbTb1PnTrlXJuXm2vq3fbmYVP9wID7PjxxosPUuyDlHlETZDOm3n2nu5xrB9vd97ckZROBqT4n
4R71Uxy3xfzE+04714bDMVPvUMz9thWL2h6OoiH3tWSzZabeuQW2tWQzCefaglSJqfdg2j0CxxoJtXDhCufaX/z6t861mcyQUx1nQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvJm0W3MDgoMIht/nY1tbm3Le4pMi0jq4u9zyw/j73zDNJyslJOtdms7bssGjEPT8qJ+G+DkkKmaql9MiIe+2wLccsoqxzbV6BLQewJ+1+PHsCt+yrt40M2vZicNQ9Z7C3131/S1KxIR8xmXTPL5SkWNL9dtgzYrv/9PS61weBbX8XFxeb6udftMS5dk5dtan3wdY3nGtHRmxZcMq459Llxtzz7iKOd0vOgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeTNosuExmRFnHLLi+Lvf8o4EBW1ZSOOw+o9NpWwbX8HDPhKxDkmIx90MbZDKm3tFozFSflXuOXdYYNNfXO+hcmzzda+odjhU61wbRPlPvxJDxrjfgvmOyWdvxHJT77Xawu9vUuy990rm2u9e2D6MJ92yyyqqZpt6tb7jnS0pSeVHKuTbXsG5Jyst3z+rrG7Ddxg++cci59kyn+7Ecdsx05AwIAOCFeQC9/PLLuvbaa1VdXa1QKKQnn3xyzPVBEOiee+5RVVWVcnJytHLlSr3++uvjtV4AwDRhHkB9fX1aunSpNm/efNbrH3zwQX3rW9/Sww8/rFdeeUV5eXlavXq1Bgfd/1QCAJj+zM8BrV27VmvXrj3rdUEQ6KGHHtKXv/xlXXfddZKk7373u6qoqNCTTz6pm2666fxWCwCYNsb1OaDW1la1t7dr5cqVo99LpVJavny5du3addb/MzQ0pO7u7jEXAMD0N64DqL29XZJUUVEx5vsVFRWj171TU1OTUqnU6KW2tnY8lwQAmKS8vwpu06ZN6urqGr1YPl4bADB1jesAqqyslCR1dHSM+X5HR8fode+USCRUWFg45gIAmP7GdQDV19ersrJSO3bsGP1ed3e3XnnlFTU0NIznjwIATHHmV8H19vbq4MGDo1+3trZq3759Ki4uVl1dne666y793d/9nebNm6f6+np95StfUXV1ta6//vrxXDcAYIozD6A9e/bok5/85OjXGzdulCStX79eW7du1Re+8AX19fXptttuU2dnpz7+8Y/r2WefVTKZNP2cIMhIjhEuOckcQ1/TMpTJZJ1rs+6l/9nbPTIlErH1ltx7RyPGKB5jlEg47B4jkw5sOzETuEV+SNJg5xlT72zy7H82Ppv8uHsUiyQp3fHBNf/NjFz3u+pQ1n2fSFJbe6d772HbH02iQdq5NqfAPXJGkmpnusfrfOQP/sDU2/rm+fYTZ3+R1dkc6Sg19a6vc9/Okf4BU2/F3B8QC0uKnWuHhoac6swD6Oqrr1bwPo/ioVBI999/v+6//35rawDAh4j3V8EBAD6cGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvzFE8F8rISEahkFsu2NCQe5ZZ2BqqZgiPs2S7/b61e++w8VeFiGE78wsKTL1dc57eFhjy3RKGXD9JiuXFnWvzM/2m3hXJHufaqoR75pkkLUj1murnxd33y5EB930iSf+n3f34HMqfYepdEnPPAYzGbQ9HpSUlzrWzZ9ebei9ffrmp/vHHvudc++bhw6bec+tnOdeWF9mOz8Cg+325d2jEuTYaGXSq4wwIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAODFpI3iscTUJJNJ51prjIzck0TM4nH3yBTLNkpSJuMem6GQbSNThYWm+vSAWyyHJGXS7sddkirkvl/+sMwWUfM/qtwjimbnG/a3pLzWDlt9l3uM0FDUtp2dqVzn2v8btkVZ
DaTdY37SaVuc0fWfut659n/+yZ+YeoeN2VdvvvE759r/t/MlU++eHvdIqOoSWxRP0nA8M9lh59qQYyoZZ0AAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALyZtFlwkElEo5DYfIxH3PKNs1j2bSrJl0llZsuCsLOvu6u6yNc/NM5UXJN3rk4MDpt7/K93mXLtmxHYsZyQWOdfm/8FKU+/2DsewrP/UefQXzrU56dOm3otC7tl++3NsmYQDF811rn2r/Zip98kTJ51rrVmKlscUSVqyZIlz7c92/dTU+/Rp9+NZlrLdNwO53ye6O93396Bj5iZnQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALyZtFE88HneO4gkC93gdayRHNOq+izIZW7xKOp2ekFpJCkfcf7eIJmw3gyBjizPKTbjv85m5uabeHxnoda6ddard1Lsn53Xn2s4O98gZSUrGbfswXmqoHbJFDi2U+z7/aEmVqffusHukTX5evqn3Sy+95Fx72fLLTb3fOvqWqf6Zf/1X59qcnBxT7/x89/0ybIwaG0m7ReZI0vBwt3NtenjYqY4zIACAFwwgAIAX5gH08ssv69prr1V1dbVCoZCefPLJMdfffPPNCoVCYy5r1qwZr/UCAKYJ8wDq6+vT0qVLtXnz5vesWbNmjY4dOzZ6eeyxx85rkQCA6cf8IoS1a9dq7dq171uTSCRUWVl5zosCAEx/E/Ic0M6dO1VeXq758+frjjvu0KlTp96zdmhoSN3d3WMuAIDpb9wH0Jo1a/Td735XO3bs0D/8wz+oublZa9eufc+XKDc1NSmVSo1eamtrx3tJAIBJaNzfB3TTTTeN/nvx4sVasmSJ5syZo507d2rFihXvqt+0aZM2btw4+nV3dzdDCAA+BCb8ZdizZ89WaWmpDh48eNbrE4mECgsLx1wAANPfhA+gI0eO6NSpU6qqsr2DGgAwvZn/BNfb2zvmbKa1tVX79u1TcXGxiouLdd9992ndunWqrKzUoUOH9IUvfEFz587V6tWrx3XhAICpzTyA9uzZo09+8pOjX7/9/M369eu1ZcsW7d+/X//8z/+szs5OVVdXa9WqVfrbv/1bJRIJ088Jy/30LGTIm8ox5JJJUjwec67NZm1ZcKdPdzrXZtK23hHHHD1JkrF3Nmyr7x8YcK4NyspNvQ9Gq51r64Zs687p73SuDSfeMPUenmfLa8vMme9c2/pGj6n3wQ73rLFW4z783cnfOdfGjY8Rw1m3vDFJavrff2/qffL4SVO95H48Fy6YZ+qcX1DgXBuL2x7fsobsuJrqGufagcFBpzrzALr66qsVBO+9s5977jlrSwDAhxBZcAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAAL8b984DGSywcVTjsNh9HRkac+8Yjtk1ORA1ZcIFtnkdCIefaZNyWk1VcVORcmxN3z9KTpJBh3ZI0nHbP7BowruW1spnuvVs7Tb1Let3XPXLU1vt0xrYPT/2u37m27aR7vpckvTXsltslSb3JLlPviOF4dnafsfWOuPc+fcqW7VaUP8NUX1pa4lxrzaOMGfIo49G4qXck5n47DGK57n3DbvmPnAEBALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALyYtFE8Fq6RPZIUki0CJWTobQuRkQytlV+QZ+o9c2aFc20kFJh6W6N4gqx7NMzxM6dMvffl1DnX7sm6R4lIUqS717k2+5NOU+9oXpmpfrDT/fgHEdvxHMxxPz7DYffYK0kKhtxvK4mYe+SMJEUNUTxxY5RVXtIWl5OX496/qDDf1Dtp2C9x4z7MGB4Pg8D9djWSyTjVcQYEAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADw
ggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8GLSZsGFQiHnzLGkIbcpkDEna2DIuTaRtOUwJZNx59qLLppp6n39ddc61545cdrU+7XXXjPV5+W6Z7C9/rtWU++u3n7n2iBly1873eN+7HNO2bLD4kHaVJ+OumVrSVIwaMtrG+h37x1NuN9mJSlmyA+LWsIRJYUNOWbGeDyFAvd9IkmJqPvaZ6QKTb1zDI8TxphGU66jpdY1n5MzIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAF5M2iicIAgWOMR7ZbNa5ryW2R5LC4Yh7bcSWg2GJ4pkzZ5apdzzu/rtFQUGeqfe8eXNM9SdPnXSunVlZbur95u8OO9fm5tq2syCWcK7NGGJ7JKm/+5SpPpR0z5LJDrjfZiVJI+63w7Dc94kkKeIeaRMyROtIUsRwfwsZo3jykrbtLC8tca4tKsw39Y7I/fEtZIwas8TrBIbHWTk+dnMGBADwwjSAmpqadNlll6mgoEDl5eW6/vrr1dLSMqZmcHBQjY2NKikpUX5+vtatW6eOjo5xXTQAYOozDaDm5mY1NjZq9+7dev7555VOp7Vq1Sr19fWN1tx99916+umn9cQTT6i5uVlHjx7VDTfcMO4LBwBMbabngJ599tkxX2/dulXl5eXau3evrrrqKnV1demRRx7Rtm3bdM0110iSHn30UV1yySXavXu3rrjiivFbOQBgSjuv54C6urokScXFxZKkvXv3Kp1Oa+XKlaM1CxYsUF1dnXbt2nXWHkNDQ+ru7h5zAQBMf+c8gLLZrO666y5deeWVWrRokSSpvb1d8XhcRUVFY2orKirU3t5+1j5NTU1KpVKjl9ra2nNdEgBgCjnnAdTY2KgDBw7o8ccfP68FbNq0SV1dXaOXtra28+oHAJgazul9QBs2bNAzzzyjl19+WTU1NaPfr6ys1PDwsDo7O8ecBXV0dKiysvKsvRKJhBIJ43sLAABTnukMKAgCbdiwQdu3b9eLL76o+vr6MdcvW7ZMsVhMO3bsGP1eS0uLDh8+rIaGhvFZMQBgWjCdATU2Nmrbtm166qmnVFBQMPq8TiqVUk5OjlKplG655RZt3LhRxcXFKiws1J133qmGhgZeAQcAGMM0gLZs2SJJuvrqq8d8/9FHH9XNN98sSfrGN76hcDisdevWaWhoSKtXr9Z3vvOdcVksAGD6MA0gl2y2ZDKpzZs3a/Pmzee8KEkqLStTJOKWaXXs6FHnvp2dnaZ1FBamnGtz83JMvUNh9xymgsJcU++8fPfMuxmFRabeXd1nTPXptHtOWm6BbTtHDNlXvT39pt45Cfd92DcyaOo9lLU975mbdl9LoGFT7yDqntcWThjywCQlwu45c5ZcMkkKh9yfQZjxjlfmfpBZdbZX4xYVFTrXxqK2rL54zPIwbXtdWTbsfv8JLDlzjo9tZMEBALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALw4p49juBB6+voUDrtFVkQNURVFSffIDEnq7upxru3v7zX1Lky5x6sUF5WYepeVlDnXhmWLBlmyZImpfmBgwLm27cgRU+9ojnukzWCv+zokKSb3iJq8wnxT7/yQe0SNJGXT7jEow0GfqXc4POJcGzXER0mSsu7rjkRtD0czimY4186pn23qXVhoi9WKGdZeUGB7DIrF3B8nhkdsUUlByP02rrDhfMWxljMgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBeTNguuu7dX4ZDbfCybkXLue8Xly0zriMfcs8ZafvtbU+/ftbrXDw4MmXqXlVU41w4NDJp6x2IxU/38+Rc717a0tJh6R+SefZVK2fLaNOKek5Ubt2WHBcZf/YYiaUO1rXkm454FGBgy6SQpGnd/iKkoLzf1rqqqcq7NSbrf
jyUpHrdl9VXPnOlcm0q5Z9iZhdxz/SQpMOQdWo68ay1nQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALyZtFE8yGlU47DYfZ9XWOPddtPAS0zpmGiI2rlhui/l5dd/PnWvDtgQUdZ4+41ybk2OLkRkYGDDV9/f3O9dGou6xMJKUybpH8STitt5B1n2np0csUTlSJuMegSJJI1n3iJVQ4L5PJCkRdX8YyM/PM/WuqnSPy6msrDT1tkgYo3guumiWqb6iwj36KhQKmXqPjNjidaYSzoAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXkzaLLhYNOycBVdd5Z7DVFNty5sqyHfPkJqRsmWqVa5a6Vz7xhtvmHr3dfc61/b09Jh6nzhxwlS/a9cu59qwMSfLUp1O2/Lawhn3LLhIxJYzFzH+6hc15LUlk/mm3oWFhc61qaIiU+9kMulcm8kMm3oXFxc7185fMN/Uu7S01FRvudlacwCDwBgEOYVwBgQA8MI0gJqamnTZZZepoKBA5eXluv7669XS0jKm5uqrr1YoFBpzuf3228d10QCAqc80gJqbm9XY2Kjdu3fr+eefVzqd1qpVq9TX1zem7tZbb9WxY8dGLw8++OC4LhoAMPWZngN69tlnx3y9detWlZeXa+/evbrqqqtGv5+bmzuhn+0BAJj6zus5oK6uLknvfjLwe9/7nkpLS7Vo0SJt2rTpfT+QbGhoSN3d3WMuAIDp75xfBZfNZnXXXXfpyiuv1KJFi0a//5nPfEazZs1SdXW19u/fry9+8YtqaWnRD3/4w7P2aWpq0n333XeuywAATFHnPIAaGxt14MAB/fjHPx7z/dtuu23034sXL1ZVVZVWrFihQ4cOac6cOe/qs2nTJm3cuHH06+7ubtXW1p7rsgAAU8Q5DaANGzbomWee0csvv6yampr3rV2+fLkk6eDBg2cdQIlEQomE7fPaAQBTn2kABUGgO++8U9u3b9fOnTtVX1//gf9n3759kqSqqqpzWiAAYHoyDaDGxkZt27ZNTz31lAoKCtTe3i5JSqVSysnJ0aFDh7Rt2zb98R//sUpKSrR//37dfffduuqqq7RkyZIJ2QAAwNRkGkBbtmyR9Ps3m/53jz76qG6++WbF43G98MILeuihh9TX16fa2lqtW7dOX/7yl8dtwQCA6cH8J7j3U1tbq+bm5vNa0NsioUDhkFsGUmWFe25TQZ4try0acc9hisZMrRWPuq9lRqrI1PvUydPOtfMumWfqnZeXZ6qfPXu2c+2r//knW1fZIOtc29frno8nSak890w16z5JxGzvgIjH3bPmcnJzTb1jMfcbbtT4xo1w2P3+U1rmnu0mSZdeeqlzbUlJiam3LWVwYo2MjPhewoQhCw4A4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4MU5fx7QRCsuKlQk4hY/MvuiWc5941HrJqedK2MRW3xHPOa+lqqKClPvAwd+5Vx7pueMqffChQtN9cPDw861Q4ODpt6RsPvvUIOZjKl3xlAfZN0jgX7f21g/4h5pk0m7729JShhuh8Uziky95y2Y71xbV1dn6p2X6x5/ZIlskqRsxn1/Sx8cU/bfhUK2x4lwyP02HpLtNm4RZN230bWWMyAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAF5M2C+76a/9YyUTCqbaqosy5ryFWSZIUjSadayPhmK25Yf4XFBaYOn/sY8uca5/b8byp9949e031p0+ddq7Njtgyu4pyc51rF198sal3Iu52+5MkGbLAJCmZNPSWlJPjfjtMpVKm3pWVlRNSK0m5ee55bVYDA+65gVFjBqQ1r82Wk2a7jVvWEjE+wEUCw3YatjFMFhwA
YDJjAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALyYtFE8CxcuUJ5jzEo47D5HY/G4aR2xmHu8TiQSMfW2xINYo0Tq59Q7166NrzX13r17t6m+LdHmXGvZ35K0aMEC59pLL73U1NtyPC23QUmKGY9n2LAW6z6MG+4TgTFyaGBgwLl2cNA9WscskzGVx6K2fRiOuh+f7IhtHwYaMVVbWBKHAkNv11rOgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeTNosuCAkZR1zisIRwxwNG8KPJMmQ8RUyZsGFDPlRIcs2ypZNNnfePFPvmpoaU31XV5dzrTUjLdcxL1A6l0PvfnysOYBW2ax7lpktDcz2H7LZrKl1bn6ec208mTT1Hk4PO9dmM7Z1Z4170bJfMsbeI4ZjnzZm3o0Y1m3p7VrLGRAAwAvTANqyZYuWLFmiwsJCFRYWqqGhQT/60Y9Grx8cHFRjY6NKSkqUn5+vdevWqaOjY9wXDQCY+kwDqKamRg888ID27t2rPXv26JprrtF1112nX/3qV5Kku+++W08//bSeeOIJNTc36+jRo7rhhhsmZOEAgKnN9Af3a6+9dszXf//3f68tW7Zo9+7dqqmp0SOPPKJt27bpmmuukSQ9+uijuuSSS7R7925dccUV47dqAMCUd87PAWUyGT3++OPq6+tTQ0OD9u7dq3Q6rZUrV47WLFiwQHV1ddq1a9d79hkaGlJ3d/eYCwBg+jMPoF/+8pfKz89XIpHQ7bffru3bt2vhwoVqb29XPB5XUVHRmPqKigq1t7e/Z7+mpialUqnRS21trXkjAABTj3kAzZ8/X/v27dMrr7yiO+64Q+vXr9evf/3rc17Apk2b1NXVNXppa3P/+GYAwNRlfh9QPB7X3LlzJUnLli3Tf/zHf+ib3/ymbrzxRg0PD6uzs3PMWVBHR4cqKyvfs18ikVAikbCvHAAwpZ33+4Cy2ayGhoa0bNkyxWIx7dixY/S6lpYWHT58WA0NDef7YwAA04zpDGjTpk1au3at6urq1NPTo23btmnnzp167rnnlEqldMstt2jjxo0qLi5WYWGh7rzzTjU0NPAKOADAu5gG0PHjx/Vnf/ZnOnbsmFKplJYsWaLnnntOf/RHfyRJ+sY3vqFwOKx169ZpaGhIq1ev1ne+851zWlhxWZny89xiPCwJK9GY7a+OsWjMuTZsjGOJWGJ+DLWSNDzsHlMyNDho6m2NnZkxY4ap3sIUUWPM4snIPaYksCW92G8r8bhzbdoQUSPZYpuygS3qxRKBE42739ck2/FJJG33e+vjRGbEfb9kjHFGOXnucVMjhnVItnWPjIw41/b19zvVmfbyI4888r7XJ5NJbd68WZs3b7a0BQB8CJEFBwDwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8MKchj3RgiCQ5B7lIE1wFE/EEsVjm+eTJYpneDht6m2JbpH+65hOBEsUT8gYxWMRDtn2SThsjOIx3LbSI8bjaVi7tbcpiidqu28ODw8510YiUzeKJ5txj8AxR/EYjo8liqf/Px+/P+i+Hwom8tHhHBw5coQPpQOAaaCtrU01NTXvef2kG0DZbFZHjx5VQUGBQqH/+o21u7tbtbW1amtrU2FhoccVTiy2c/r4MGyjxHZON+OxnUEQqKenR9XV1e/7F5NJ9ye4cDj8vhOzsLBwWh/8t7Gd08eHYRsltnO6Od/tTKVSH1jDixAAAF4wgAAAXkyZAZRIJHTvvfcqkUj4XsqEYjunjw/DNkps53RzIbdz0r0IAQDw4TBlzoAAANMLAwgA4AUDCADgBQMIAODFlBlAmzdv1kUXXaRkMqnly5frZz/7me8ljauvfvWrCoVCYy4LFizwvazz8vLLL+vaa69VdXW1QqGQnnzyyTHXB0Gge+65R1VVVcrJ
ydHKlSv1+uuv+1nsefig7bz55pvfdWzXrFnjZ7HnqKmpSZdddpkKCgpUXl6u66+/Xi0tLWNqBgcH1djYqJKSEuXn52vdunXq6OjwtOJz47KdV1999buO5+233+5pxedmy5YtWrJkyeibTRsaGvSjH/1o9PoLdSynxAD6/ve/r40bN+ree+/Vz3/+cy1dulSrV6/W8ePHfS9tXF166aU6duzY6OXHP/6x7yWdl76+Pi1dulSbN28+6/UPPvigvvWtb+nhhx/WK6+8ory8PK1evVqDg4MXeKXn54O2U5LWrFkz5tg+9thjF3CF56+5uVmNjY3avXu3nn/+eaXTaa1atUp9fX2jNXfffbeefvppPfHEE2pubtbRo0d1ww03eFy1nct2StKtt9465ng++OCDnlZ8bmpqavTAAw9o79692rNnj6655hpdd911+tWvfiXpAh7LYAq4/PLLg8bGxtGvM5lMUF1dHTQ1NXlc1fi69957g6VLl/pexoSRFGzfvn3062w2G1RWVgZf+9rXRr/X2dkZJBKJ4LHHHvOwwvHxzu0MgiBYv359cN1113lZz0Q5fvx4IClobm4OguD3xy4WiwVPPPHEaM1rr70WSAp27drla5nn7Z3bGQRB8Id/+IfBX/7lX/pb1ASZMWNG8I//+I8X9FhO+jOg4eFh7d27VytXrhz9Xjgc1sqVK7Vr1y6PKxt/r7/+uqqrqzV79mx99rOf1eHDh30vacK0traqvb19zHFNpVJavnz5tDuukrRz506Vl5dr/vz5uuOOO3Tq1CnfSzovXV1dkqTi4mJJ0t69e5VOp8cczwULFqiurm5KH893bufbvve976m0tFSLFi3Spk2bRj9+YCrKZDJ6/PHH1dfXp4aGhgt6LCddGOk7nTx5UplMRhUVFWO+X1FRod/85jeeVjX+li9frq1bt2r+/Pk6duyY7rvvPn3iE5/QgQMHVFBQ4Ht54669vV2Sznpc375uulizZo1uuOEG1dfX69ChQ/qbv/kbrV27Vrt27VIkYvtcoMkgm83qrrvu0pVXXqlFixZJ+v3xjMfjKioqGlM7lY/n2bZTkj7zmc9o1qxZqq6u1v79+/XFL35RLS0t+uEPf+hxtXa//OUv1dDQoMHBQeXn52v79u1auHCh9u3bd8GO5aQfQB8Wa9euHf33kiVLtHz5cs2aNUs/+MEPdMstt3hcGc7XTTfdNPrvxYsXa8mSJZozZ4527typFStWeFzZuWlsbNSBAwem/HOUH+S9tvO2224b/ffixYtVVVWlFStW6NChQ5ozZ86FXuY5mz9/vvbt26euri79y7/8i9avX6/m5uYLuoZJ/ye40tJSRSKRd70Co6OjQ5WVlZ5WNfGKiop08cUX6+DBg76XMiHePnYftuMqSbNnz1ZpaemUPLYbNmzQM888o5deemnMx6ZUVlZqeHhYnZ2dY+qn6vF8r+08m+XLl0vSlDue8Xhcc+fO1bJly9TU1KSlS5fqm9/85gU9lpN+AMXjcS1btkw7duwY/V42m9WOHTvU0NDgcWUTq7e3V4cOHVJVVZXvpUyI+vp6VVZWjjmu3d3deuWVV6b1cZV+/6m/p06dmlLHNggCbdiwQdu3b9eLL76o+vr6MdcvW7ZMsVhszPFsaWnR4cOHp9Tx/KDtPJt9+/ZJ0pQ6nmeTzWY1NDR0YY/luL6kYYI8/vjjQSKRCLZu3Rr8+te/Dm677bagqKgoaG9v9720cfNXf/VXwc6dO4PW1tbgJz/5SbBy5cqgtLQ0OH78uO+lnbOenp7g1VdfDV599dVAUvD1r389ePXVV4M333wzCIIgeOCBB4KioqLgqaeeCvbv3x9cd911QX19fTAwMOB55Tbvt509PT3B5z//+WDXrl1Ba2tr8MILLwQf/ehHg3nz5gWDg4O+l+7sjjvuCFKpVLBz587g2LFjo5f+/v7Rmttvvz2oq6sLXnzxxWDPnj1BQ0ND0NDQ4HHVdh+0nQcPHgzuv//+
YM+ePUFra2vw1FNPBbNnzw6uuuoqzyu3+dKXvhQ0NzcHra2twf79+4MvfelLQSgUCv793/89CIILdyynxAAKgiD49re/HdTV1QXxeDy4/PLLg927d/te0ri68cYbg6qqqiAejwczZ84MbrzxxuDgwYO+l3VeXnrppUDSuy7r168PguD3L8X+yle+ElRUVASJRCJYsWJF0NLS4nfR5+D9trO/vz9YtWpVUFZWFsRisWDWrFnBrbfeOuV+eTrb9kkKHn300dGagYGB4C/+4i+CGTNmBLm5ucGnPvWp4NixY/4WfQ4+aDsPHz4cXHXVVUFxcXGQSCSCuXPnBn/9138ddHV1+V240Z//+Z8Hs2bNCuLxeFBWVhasWLFidPgEwYU7lnwcAwDAi0n/HBAAYHpiAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8+P97tlGt2lCeiQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imshow(X_data[pri_risk_rank[idx]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "9a280b28-71e2-4ac0-be7a-cab54c5af113",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_data[pri_risk_rank[idx]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63075866-0819-4386-8e3d-ee79a873b1fd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "opacus",
   "language": "python",
   "name": "opacus"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
